Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions test/python/test_xcsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import json
import os
import pickle
import numbers
import numpy as np
import pytest
from copy import deepcopy
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import make_regression
Expand Down Expand Up @@ -289,3 +291,93 @@ def test_seeding(data):
# clean up
if os.path.exists(POP_FILENAME):
os.remove(POP_FILENAME)


def _compare_dicts(d1, d2, path=""):
diffs = []
all_keys = set(d1.keys()) | set(d2.keys())

for key in all_keys:
subpath = f"{path}.{key}" if path else key

if key not in d1 or key not in d2:
diffs.append((subpath, "Path exists in only one dict"))
continue

v1, v2 = d1[key], d2[key]

if isinstance(v1, dict) and isinstance(v2, dict):
diffs.extend(_compare_dicts(v1, v2, subpath))
elif isinstance(v1, list) and isinstance(v2, list):
if len(v1) != len(v2):
diffs.append((subpath, f"List length differs: {len(v1)} != {len(v2)}"))
for i, (x, y) in enumerate(zip(v1, v2)):
diffs.extend(_compare_dicts({0: x}, {0: y}, f"{subpath}[{i}]"))
elif isinstance(v1, numbers.Real) and isinstance(v2, numbers.Real):
if not np.isclose(v1, v2, atol=1e-10, rtol=0.0):
diffs.append((subpath, f"{v1} != {v2}"))
elif v1 != v2:
diffs.append((subpath, f"{v1} != {v2}"))

return diffs


def _test_pop_replace(tmp_path, pop_init, clean, fitinbetween, warm_start):
N = 500
DX = 3
X = np.random.random((N, DX))
y = np.random.randn(N, 1)

xcs = xcsf.XCS(x_dim=DX, pop_size=5, max_trials=1000, pop_init=pop_init)
xcs.fit(X, y, verbose=False)

# Initial, “too large” population.
json0 = xcs.json()
pop0 = json.loads(json0)

# “Pruning”.
pop1 = deepcopy(pop0)
del pop1["classifiers"][0]
json1 = json.dumps(pop1)
(tmp_path / "pset1.json").write_text(json1)

if fitinbetween:
xcs.fit(X, y, warm_start=True, verbose=False)

xcs.json_read(str(tmp_path / "pset1.json"), clean=clean)

# Pipe through `loads` b/c that was done above as well.
json2 = json.dumps(json.loads(xcs.json()))

list1 = json.loads(json1)["classifiers"]
list2 = json.loads(json2)["classifiers"]

if len(list1) != len(list2):
return False
else:
unequal = False
for cl1, cl2 in zip(list1, list2):
# If there is any difference, …
if _compare_dicts(cl1, cl2):
unequal = True
break
return not unequal


@pytest.mark.parametrize(
"pop_init,clean,fitinbetween,warm_start",
[
(False, True, False, False),
(False, True, True, False),
(False, True, True, True),
(True, True, False, False),
(True, True, True, False),
(True, True, True, True),
],
)
def test_pop_replace(tmp_path, pop_init, clean, fitinbetween, warm_start):
for seed in range(19):
np.random.seed(seed)
assert _test_pop_replace(
tmp_path, pop_init, clean, fitinbetween, warm_start
), f"failed at seed {seed}"
8 changes: 6 additions & 2 deletions xcsf/pybind_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,8 +943,12 @@ class XCS
* @param [in] filename Name of the input file.
*/
void
json_read(const std::string &filename)
json_read(const std::string &filename, const bool clean)
{
if (clean) {
clset_kill(&xcs, &xcs.pset);
clset_init(&xcs.pset);
}
std::ifstream infile(filename);
std::stringstream buffer;
buffer << infile.rdbuf();
Expand Down Expand Up @@ -1105,7 +1109,7 @@ PYBIND11_MODULE(xcsf, m)
py::arg("filename"))
.def("json_read", &XCS::json_read,
"Reads classifiers from a JSON file and adds to the population.",
py::arg("filename"))
py::arg("filename"), py::arg("clean") = true)
.def("get_params", &XCS::get_params, py::arg("deep") = true,
"Returns a dictionary of parameters and their values.")
.def("set_params", &XCS::set_params, "Sets parameters.")
Expand Down
Loading