Skip to content

Commit 6ad4320

Browse files
committed
add tests for storage
1 parent abf991e commit 6ad4320

1 file changed

Lines changed: 166 additions & 0 deletions

File tree

tests/test_main/test_storage.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
"""Tests for the pluggable storage system."""
2+
3+
import os
4+
import tempfile
5+
6+
import numpy as np
7+
import pytest
8+
9+
from gradient_free_optimizers import HillClimbingOptimizer, RandomSearchOptimizer
10+
from gradient_free_optimizers._result import Result
11+
from gradient_free_optimizers.storage import BaseStorage, MemoryStorage, SQLiteStorage
12+
13+
search_space = {"x": np.linspace(-10, 10, 100)}
14+
15+
16+
def objective(para):
17+
return -(para["x"] ** 2)
18+
19+
20+
class TestMemoryStorage:
21+
def test_get_put_contains(self):
22+
ms = MemoryStorage()
23+
assert not ms.contains((1, 2))
24+
assert ms.get((1, 2)) is None
25+
26+
ms.put((1, 2), Result(0.5, {}))
27+
assert ms.contains((1, 2))
28+
assert ms.get((1, 2)).score == 0.5
29+
30+
def test_len(self):
31+
ms = MemoryStorage()
32+
assert len(ms) == 0
33+
ms.put((1,), Result(0.1, {}))
34+
ms.put((2,), Result(0.2, {}))
35+
assert len(ms) == 2
36+
37+
def test_update(self):
38+
ms = MemoryStorage()
39+
ms.update({(1,): Result(0.1, {}), (2,): Result(0.2, {})})
40+
assert len(ms) == 2
41+
assert ms.get((2,)).score == 0.2
42+
43+
def test_items(self):
44+
ms = MemoryStorage()
45+
ms.put((1,), Result(0.1, {}))
46+
ms.put((2,), Result(0.2, {}))
47+
items = list(ms.items())
48+
assert len(items) == 2
49+
50+
def test_overwrite(self):
51+
ms = MemoryStorage()
52+
ms.put((1,), Result(0.1, {}))
53+
ms.put((1,), Result(0.9, {}))
54+
assert ms.get((1,)).score == 0.9
55+
assert len(ms) == 1
56+
57+
def test_isinstance(self):
58+
assert isinstance(MemoryStorage(), BaseStorage)
59+
60+
def test_with_search(self):
61+
storage = MemoryStorage()
62+
opt = HillClimbingOptimizer(search_space)
63+
opt.search(objective, n_iter=20, memory=storage, verbosity=False)
64+
assert len(storage) > 0
65+
assert opt.best_score is not None
66+
67+
68+
class TestSQLiteStorage:
69+
def test_get_put_contains(self):
70+
with tempfile.TemporaryDirectory() as td:
71+
ss = SQLiteStorage(os.path.join(td, "test.db"))
72+
assert not ss.contains((1, 2))
73+
assert ss.get((1, 2)) is None
74+
75+
ss.put((1, 2), Result(0.5, {"loss": 0.5}))
76+
assert ss.contains((1, 2))
77+
r = ss.get((1, 2))
78+
assert r.score == 0.5
79+
assert r.metrics == {"loss": 0.5}
80+
ss.close()
81+
82+
def test_persistence(self):
83+
with tempfile.TemporaryDirectory() as td:
84+
path = os.path.join(td, "persist.db")
85+
ss = SQLiteStorage(path)
86+
ss.put((1,), Result(0.1, {}))
87+
ss.put((2,), Result(0.2, {"m": 42}))
88+
ss.close()
89+
90+
ss2 = SQLiteStorage(path)
91+
assert len(ss2) == 2
92+
assert ss2.get((1,)).score == 0.1
93+
assert ss2.get((2,)).metrics == {"m": 42}
94+
ss2.close()
95+
96+
def test_bulk_update(self):
97+
ss = SQLiteStorage(":memory:")
98+
data = {(i,): Result(float(i), {}) for i in range(100)}
99+
ss.update(data)
100+
assert len(ss) == 100
101+
assert ss.get((50,)).score == 50.0
102+
ss.close()
103+
104+
def test_items_lazy(self):
105+
ss = SQLiteStorage(":memory:")
106+
for i in range(50):
107+
ss.put((i,), Result(float(i), {}))
108+
items = list(ss.items())
109+
assert len(items) == 50
110+
ss.close()
111+
112+
def test_isinstance(self):
113+
ss = SQLiteStorage(":memory:")
114+
assert isinstance(ss, BaseStorage)
115+
ss.close()
116+
117+
def test_with_search(self):
118+
with tempfile.TemporaryDirectory() as td:
119+
path = os.path.join(td, "search.db")
120+
storage = SQLiteStorage(path)
121+
opt = HillClimbingOptimizer(search_space)
122+
opt.search(objective, n_iter=20, memory=storage, verbosity=False)
123+
assert len(storage) > 0
124+
storage.close()
125+
126+
def test_crash_recovery(self):
127+
"""Second search with same storage skips cached positions."""
128+
with tempfile.TemporaryDirectory() as td:
129+
path = os.path.join(td, "recovery.db")
130+
storage = SQLiteStorage(path)
131+
132+
opt1 = RandomSearchOptimizer(search_space)
133+
opt1.search(objective, n_iter=20, memory=storage, verbosity=False)
134+
cached_after_first = len(storage)
135+
136+
opt2 = RandomSearchOptimizer(search_space)
137+
opt2.search(objective, n_iter=20, memory=storage, verbosity=False)
138+
cached_after_second = len(storage)
139+
140+
assert cached_after_second >= cached_after_first
141+
storage.close()
142+
143+
144+
class TestMemoryParameterTypes:
145+
def test_memory_true(self):
146+
opt = HillClimbingOptimizer(search_space)
147+
opt.search(objective, n_iter=10, memory=True, verbosity=False)
148+
assert opt.best_score is not None
149+
150+
def test_memory_false(self):
151+
opt = HillClimbingOptimizer(search_space)
152+
opt.search(objective, n_iter=10, memory=False, verbosity=False)
153+
assert opt.best_score is not None
154+
155+
def test_memory_storage_instance(self):
156+
storage = MemoryStorage()
157+
opt = HillClimbingOptimizer(search_space)
158+
opt.search(objective, n_iter=10, memory=storage, verbosity=False)
159+
assert len(storage) > 0
160+
161+
def test_memory_sqlite_instance(self):
162+
ss = SQLiteStorage(":memory:")
163+
opt = HillClimbingOptimizer(search_space)
164+
opt.search(objective, n_iter=10, memory=ss, verbosity=False)
165+
assert len(ss) > 0
166+
ss.close()

0 commit comments

Comments
 (0)