diff --git a/test/python/test_xcsf.py b/test/python/test_xcsf.py
index 0d3818ace..15fd1a0f9 100644
--- a/test/python/test_xcsf.py
+++ b/test/python/test_xcsf.py
@@ -25,8 +25,10 @@
 import json
 import os
 import pickle
+import numbers
 import numpy as np
 import pytest
+from copy import deepcopy
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import MinMaxScaler
 from sklearn.datasets import make_regression
@@ -289,3 +291,111 @@ def test_seeding(data):
     # clean up
     if os.path.exists(POP_FILENAME):
         os.remove(POP_FILENAME)
+
+
+def _compare_values(v1, v2, path):
+    """Return ``(path, message)`` tuples describing how two values differ.
+
+    Dicts and lists are compared recursively, real numbers with an absolute
+    tolerance of 1e-10, and everything else with ``!=``.
+    """
+    if isinstance(v1, dict) and isinstance(v2, dict):
+        return _compare_dicts(v1, v2, path)
+    if isinstance(v1, list) and isinstance(v2, list):
+        diffs = []
+        if len(v1) != len(v2):
+            diffs.append((path, f"List length differs: {len(v1)} != {len(v2)}"))
+        # zip() stops at the shorter list; the length mismatch is reported above.
+        for i, (x, y) in enumerate(zip(v1, v2)):
+            diffs.extend(_compare_values(x, y, f"{path}[{i}]"))
+        return diffs
+    if isinstance(v1, numbers.Real) and isinstance(v2, numbers.Real):
+        if np.isclose(v1, v2, atol=1e-10, rtol=0.0):
+            return []
+        return [(path, f"{v1} != {v2}")]
+    if v1 != v2:
+        return [(path, f"{v1} != {v2}")]
+    return []
+
+
+def _compare_dicts(d1, d2, path=""):
+    """Return ``(path, message)`` tuples describing how two dicts differ.
+
+    An empty list means no difference was found. ``path`` is the prefix
+    prepended to every reported key.
+    """
+    diffs = []
+    # Sort by string form so the report order is stable between runs.
+    for key in sorted(set(d1) | set(d2), key=str):
+        subpath = f"{path}.{key}" if path else key
+
+        if key not in d1 or key not in d2:
+            diffs.append((subpath, "Path exists in only one dict"))
+            continue
+
+        diffs.extend(_compare_values(d1[key], d2[key], subpath))
+
+    return diffs
+
+
+def _test_pop_replace(tmp_path, pop_init, clean, fitinbetween, warm_start):
+    """Check that ``json_read`` loads a pruned population unchanged.
+
+    Fits a small model on random data, writes its population minus the first
+    classifier to a JSON file, optionally fits again, and reads the file back
+    with ``json_read(..., clean=clean)``.
+
+    Returns ``True`` if the classifiers the model reports afterwards match
+    the ones written to the file, ``False`` otherwise.
+    """
+    N = 500
+    DX = 3
+    X = np.random.random((N, DX))
+    y = np.random.randn(N, 1)
+
+    xcs = xcsf.XCS(x_dim=DX, pop_size=5, max_trials=1000, pop_init=pop_init)
+    xcs.fit(X, y, verbose=False)
+
+    # Initial, “too large” population.
+    pop0 = json.loads(xcs.json())
+
+    # “Pruning”.
+    pop1 = deepcopy(pop0)
+    del pop1["classifiers"][0]
+    json1 = json.dumps(pop1)
+    pop_file = tmp_path / "pset1.json"
+    pop_file.write_text(json1)
+
+    if fitinbetween:
+        # Honour the parametrised flag instead of always warm starting.
+        xcs.fit(X, y, warm_start=warm_start, verbose=False)
+
+    xcs.json_read(str(pop_file), clean=clean)
+
+    # Both sides go through `loads` so they are compared in the same form.
+    expected = json.loads(json1)["classifiers"]
+    actual = json.loads(xcs.json())["classifiers"]
+
+    if len(expected) != len(actual):
+        return False
+    return not any(_compare_dicts(cl1, cl2) for cl1, cl2 in zip(expected, actual))
+
+
+@pytest.mark.parametrize(
+    "pop_init,clean,fitinbetween,warm_start",
+    [
+        (False, True, False, False),
+        (False, True, True, False),
+        (False, True, True, True),
+        (True, True, False, False),
+        (True, True, True, False),
+        (True, True, True, True),
+    ],
+)
+def test_pop_replace(tmp_path, pop_init, clean, fitinbetween, warm_start):
+    """Run the population-replacement check for seeds 0 to 18."""
+    for seed in range(19):
+        np.random.seed(seed)
+        assert _test_pop_replace(
+            tmp_path, pop_init, clean, fitinbetween, warm_start
+        ), f"failed at seed {seed}"
diff --git a/xcsf/pybind_wrapper.cpp b/xcsf/pybind_wrapper.cpp
index 6898372b7..aee42bc31 100644
--- a/xcsf/pybind_wrapper.cpp
+++ b/xcsf/pybind_wrapper.cpp
@@ -943,8 +943,13 @@ class XCS
      * @param [in] filename Name of the input file.
+     * @param [in] clean Whether to clear the existing population first.
      */
     void
-    json_read(const std::string &filename)
+    json_read(const std::string &filename, const bool clean)
     {
+        if (clean) {
+            clset_kill(&xcs, &xcs.pset);
+            clset_init(&xcs.pset);
+        }
         std::ifstream infile(filename);
         std::stringstream buffer;
         buffer << infile.rdbuf();
@@ -1105,7 +1110,8 @@ PYBIND11_MODULE(xcsf, m)
              py::arg("filename"))
         .def("json_read", &XCS::json_read,
-             "Reads classifiers from a JSON file and adds to the population.",
-             py::arg("filename"))
+             "Reads classifiers from a JSON file into the population, "
+             "replacing the existing one when clean is true.",
+             py::arg("filename"), py::arg("clean") = true)
         .def("get_params", &XCS::get_params, py::arg("deep") = true,
              "Returns a dictionary of parameters and their values.")
         .def("set_params", &XCS::set_params, "Sets parameters.")