Skip to content

Commit 5af18d2

Browse files
committed
improve test coverage of distributed module
1 parent fdb98fd commit 5af18d2

1 file changed

Lines changed: 343 additions & 0 deletions

File tree

tests/test_main/test_distributed.py

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Tests for distributed evaluation backends and search loop integration."""
22

33
import concurrent.futures
4+
import os
45

56
import numpy as np
67
import pytest
78

9+
import gradient_free_optimizers._distributed._multiprocessing as _mp_module
810
from gradient_free_optimizers import (
911
BayesianOptimizer,
1012
DirectAlgorithm,
@@ -16,8 +18,10 @@
1618
)
1719
from gradient_free_optimizers._distributed import (
1820
BaseDistribution,
21+
Dask,
1922
Joblib,
2023
Multiprocessing,
24+
Ray,
2125
)
2226
from gradient_free_optimizers._storage import MemoryStorage
2327

@@ -666,3 +670,342 @@ def test_metric_columns_match_serial(self):
666670
assert serial_cols == dist_cols, f"Column mismatch: {serial_cols ^ dist_cols}"
667671
assert "loss" in dist_cols
668672
assert "x_abs" in dist_cols
673+
674+
675+
def _identity_x(para):
676+
"""Returns para["x"] directly. Module-level for multiprocessing pickling."""
677+
return para["x"]
678+
679+
680+
def _identity_x_with_metrics(para):
681+
"""Returns (score, metrics) tuple. Module-level for multiprocessing pickling."""
682+
return para["x"], {"doubled": para["x"] * 2}
683+
684+
685+
class TestBaseDistributionValidation:
    """Constructor validation and default hooks of BaseDistribution."""

    def _make_concrete(self):
        """Build a minimal BaseDistribution subclass that evaluates in-process.

        The subclass is deliberately named ``Concrete``: the
        NotImplementedError tests below match that name in the message.
        """

        class Concrete(BaseDistribution):
            def _distribute(self, func, params_batch):
                # Serial stand-in: apply func to every parameter set in order.
                return list(map(func, params_batch))

        return Concrete

    def test_n_workers_zero_raises(self):
        """``n_workers=0`` is rejected with a ValueError."""
        backend_cls = self._make_concrete()
        with pytest.raises(ValueError, match="n_workers must be >= 1"):
            backend_cls(n_workers=0)

    def test_n_workers_negative_raises(self):
        """``n_workers=-2`` is rejected with a ValueError."""
        backend_cls = self._make_concrete()
        with pytest.raises(ValueError, match="n_workers must be >= 1"):
            backend_cls(n_workers=-2)

    def test_n_workers_one_accepted(self):
        """The smallest valid worker count, 1, is accepted and stored."""
        backend_cls = self._make_concrete()
        backend = backend_cls(n_workers=1)
        assert backend.n_workers == 1

    def test_n_workers_stored(self):
        """An arbitrary valid worker count is exposed via ``n_workers``."""
        backend_cls = self._make_concrete()
        backend = backend_cls(n_workers=7)
        assert backend.n_workers == 7

    def test_is_async_default_false(self):
        """The base class declares ``_is_async`` as ``False``."""
        assert BaseDistribution._is_async is False

    def test_submit_raises_not_implemented(self):
        """``_submit`` on a subclass that does not override it raises.

        The error message is expected to contain the subclass name.
        """
        backend_cls = self._make_concrete()
        backend = backend_cls(n_workers=1)
        with pytest.raises(NotImplementedError, match="Concrete"):
            backend._submit(lambda x: x, {})

    def test_wait_any_raises_not_implemented(self):
        """``_wait_any`` on a subclass that does not override it raises.

        The error message is expected to contain the subclass name.
        """
        backend_cls = self._make_concrete()
        backend = backend_cls(n_workers=1)
        with pytest.raises(NotImplementedError, match="Concrete"):
            backend._wait_any([])
729+
730+
731+
class TestDistributeWrapper:
    """Checks on the callable returned by ``BaseDistribution.distribute``.

    NOTE(review): ``objective`` is not defined in this class; presumably it
    is a module-level objective function declared earlier in this file.
    """

    def _make_backend(self, n_workers=3):
        """Return an instance of an in-process BaseDistribution subclass."""

        class Serial(BaseDistribution):
            def _distribute(self, func, params_batch):
                # Evaluate sequentially so the tests need no worker pool.
                return list(map(func, params_batch))

        return Serial(n_workers=n_workers)

    def test_gfo_distributed_flag(self):
        """The wrapper carries ``_gfo_distributed`` set to ``True``."""
        backend = self._make_backend()
        wrapped = backend.distribute(objective)
        assert wrapped._gfo_distributed is True

    def test_gfo_batch_size_matches_n_workers(self):
        """``_gfo_batch_size`` equals the backend's worker count."""
        backend = self._make_backend(n_workers=5)
        wrapped = backend.distribute(objective)
        assert wrapped._gfo_batch_size == 5

    def test_gfo_original_func_is_original(self):
        """``_gfo_original_func`` is the very function that was wrapped."""
        backend = self._make_backend()
        wrapped = backend.distribute(objective)
        assert wrapped._gfo_original_func is objective

    def test_gfo_backend_is_backend_instance(self):
        """``_gfo_backend`` is the backend instance that did the wrapping."""
        backend = self._make_backend()
        wrapped = backend.distribute(objective)
        assert wrapped._gfo_backend is backend

    def test_preserves_function_name(self):
        """The wrapper keeps the ``__name__`` of a named function."""
        backend = self._make_backend()
        wrapped = backend.distribute(objective)
        assert wrapped.__name__ == "objective"

    def test_preserves_lambda_name(self):
        """The wrapper keeps the ``__name__`` of a lambda."""
        backend = self._make_backend()
        wrapped = backend.distribute(lambda para: para["x"])
        assert wrapped.__name__ == "<lambda>"

    def test_wrapper_delegates_to_distribute(self):
        """Calling the wrapper on a batch yields one result per entry, in order."""
        backend = self._make_backend()
        wrapped = backend.distribute(_identity_x)
        batch = [{"x": 10}, {"x": 20}]
        assert wrapped(batch) == [10, 20]

    def test_wrapper_is_callable(self):
        """The object returned by ``distribute`` is callable."""
        backend = self._make_backend()
        wrapped = backend.distribute(objective)
        assert callable(wrapped)
775+
776+
777+
class TestMultiprocessingUnit:
    """Direct checks on the Multiprocessing backend object."""

    def test_auto_detect_workers(self):
        """``n_workers=-1`` resolves to the CPU count, or 1 if it is unknown."""
        backend = Multiprocessing(n_workers=-1)
        cpu_total = os.cpu_count() or 1
        assert backend.n_workers == cpu_total

    def test_explicit_workers(self):
        """An explicit worker count is stored unchanged."""
        backend = Multiprocessing(n_workers=3)
        assert backend.n_workers == 3

    def test_prefers_fork_when_available(self):
        """``_use_fork`` mirrors whether the platform offers the fork method."""
        import multiprocessing

        backend = Multiprocessing(n_workers=2)
        fork_available = "fork" in multiprocessing.get_all_start_methods()
        # Identity comparison against a real bool covers both branches.
        assert backend._use_fork is fork_available

    def test_context_has_valid_start_method(self):
        """The stored context reports one of the three standard start methods."""
        backend = Multiprocessing(n_workers=2)
        start_method = backend._mp_context.get_start_method()
        assert start_method in ("fork", "spawn", "forkserver")

    def test_result_ordering(self):
        """Results come back in the same order as the submitted batch."""
        backend = Multiprocessing(n_workers=2)
        expected = list(range(10))
        batch = [{"x": value} for value in expected]
        assert backend._distribute(_identity_x, batch) == expected

    def test_single_item_batch(self):
        """A batch of one parameter set yields a one-element result list."""
        backend = Multiprocessing(n_workers=2)
        outcome = backend._distribute(_identity_x, [{"x": 42}])
        assert outcome == [42]

    def test_empty_batch(self):
        """An empty batch yields an empty result list."""
        backend = Multiprocessing(n_workers=2)
        outcome = backend._distribute(_identity_x, [])
        assert outcome == []

    def test_worker_func_cleaned_up_after_distribute(self):
        """The module-level ``_worker_func`` is None before and after a run."""
        backend = Multiprocessing(n_workers=2)
        assert _mp_module._worker_func is None
        backend._distribute(_identity_x, [{"x": 1}])
        assert _mp_module._worker_func is None

    def test_tuple_result_passthrough(self):
        """``(score, metrics)`` tuples are returned unmodified."""
        backend = Multiprocessing(n_workers=2)
        outcome = backend._distribute(
            _identity_x_with_metrics, [{"x": 1}, {"x": 2}]
        )
        first, second = outcome[0], outcome[1]
        assert first == (1, {"doubled": 2})
        assert second == (2, {"doubled": 4})
830+
831+
832+
class TestJoblibUnit:
    """Direct checks on the Joblib backend object."""

    def test_auto_detect_workers(self):
        """``n_workers=-1`` resolves to at least one worker."""
        backend = Joblib(n_workers=-1)
        assert backend.n_workers >= 1

    def test_explicit_workers(self):
        """An explicit worker count is stored unchanged."""
        backend = Joblib(n_workers=3)
        assert backend.n_workers == 3

    def test_default_backend_is_loky(self):
        """Without a ``backend`` argument, ``_backend_name`` is ``"loky"``."""
        backend = Joblib(n_workers=2)
        assert backend._backend_name == "loky"

    def test_custom_backend_stored(self):
        """A ``backend`` argument is stored in ``_backend_name``."""
        backend = Joblib(n_workers=2, backend="threading")
        assert backend._backend_name == "threading"

    def test_result_ordering(self):
        """Results come back in the same order as the submitted batch."""
        backend = Joblib(n_workers=2)
        expected = list(range(10))
        batch = [{"x": value} for value in expected]
        assert backend._distribute(_identity_x, batch) == expected

    def test_single_item_batch(self):
        """A batch of one parameter set yields a one-element result list."""
        backend = Joblib(n_workers=2)
        outcome = backend._distribute(_identity_x, [{"x": 99}])
        assert outcome == [99]

    def test_empty_batch(self):
        """An empty batch yields an empty result list."""
        backend = Joblib(n_workers=2)
        outcome = backend._distribute(_identity_x, [])
        assert outcome == []

    def test_threading_backend_produces_correct_results(self):
        """The ``"threading"`` backend returns ordered results as well."""
        backend = Joblib(n_workers=2, backend="threading")
        expected = list(range(5))
        batch = [{"x": value} for value in expected]
        assert backend._distribute(_identity_x, batch) == expected

    def test_tuple_result_passthrough(self):
        """``(score, metrics)`` tuples are returned unmodified."""
        backend = Joblib(n_workers=2)
        outcome = backend._distribute(
            _identity_x_with_metrics, [{"x": 1}, {"x": 2}]
        )
        first, second = outcome[0], outcome[1]
        assert first == (1, {"doubled": 2})
        assert second == (2, {"doubled": 4})
879+
880+
881+
class TestRayUnit:
    """Unit tests for the Ray backend."""

    @pytest.fixture(autouse=True)
    def _ray_lifecycle(self):
        """Ensure a Ray session exists for each test; skip if Ray is missing.

        Only a session started by this fixture is shut down afterwards. A
        session that was already running (for example one owned by a
        broader-scoped fixture) is reused and left untouched, so these tests
        do not tear down state they do not own.
        """
        ray = pytest.importorskip("ray")
        started_here = not ray.is_initialized()
        if started_here:
            ray.init(num_cpus=2, ignore_reinit_error=True)
        yield
        if started_here:
            ray.shutdown()

    def test_is_async(self):
        """The Ray backend class declares ``_is_async`` as ``True``."""
        assert Ray._is_async is True

    def test_remote_cache_initially_empty(self):
        """A fresh backend starts with an empty ``_remote_cache`` dict."""
        r = Ray(n_workers=2)
        assert r._remote_cache == {}

    def test_remote_caches_wrapper(self):
        """``_remote`` returns the same object when given the same function."""
        r = Ray(n_workers=2)
        remote1 = r._remote(_identity_x)
        remote2 = r._remote(_identity_x)
        assert remote1 is remote2

    def test_remote_different_funcs_get_separate_entries(self):
        """Two distinct functions produce two ``_remote_cache`` entries."""
        r = Ray(n_workers=2)
        r._remote(_identity_x)
        r._remote(_identity_x_with_metrics)
        assert len(r._remote_cache) == 2

    def test_result_ordering(self):
        """Results come back in the same order as the submitted batch."""
        r = Ray(n_workers=2)
        batch = [{"x": i} for i in range(8)]
        results = r._distribute(_identity_x, batch)
        assert results == list(range(8))

    def test_single_item_batch(self):
        """A batch of one parameter set yields a one-element result list."""
        r = Ray(n_workers=1)
        results = r._distribute(_identity_x, [{"x": 7}])
        assert results == [7]

    def test_empty_batch(self):
        """An empty batch yields an empty result list."""
        r = Ray(n_workers=1)
        results = r._distribute(_identity_x, [])
        assert results == []

    def test_submit_wait_roundtrip(self):
        """``_submit`` then ``_wait_any`` returns that future and its result."""
        r = Ray(n_workers=2)
        future = r._submit(_identity_x, {"x": 42})
        completed, result = r._wait_any([future])
        assert result == 42
        assert completed is future

    def test_tuple_result_passthrough(self):
        """``(score, metrics)`` tuples are returned unmodified."""
        r = Ray(n_workers=2)
        batch = [{"x": 1}, {"x": 2}]
        results = r._distribute(_identity_x_with_metrics, batch)
        assert results[0] == (1, {"doubled": 2})
        assert results[1] == (2, {"doubled": 4})
939+
940+
941+
class TestDaskUnit:
    """Direct checks on the Dask backend object."""

    @pytest.fixture(autouse=True)
    def _require_dask(self):
        """Skip every test in this class when ``dask.distributed`` is missing."""
        pytest.importorskip("dask.distributed")

    @pytest.fixture
    def backend(self):
        """Yield a one-worker Dask backend and close its client afterwards.

        The client is only closed if one was actually created during the test.
        """
        dask_backend = Dask(n_workers=1)
        yield dask_backend
        client = dask_backend._client
        if client is not None:
            client.close()

    def test_is_async(self):
        """The Dask backend class declares ``_is_async`` as ``True``."""
        assert Dask._is_async is True

    def test_client_not_created_at_init(self):
        """Constructing the backend leaves ``_client`` as None."""
        dask_backend = Dask(n_workers=2)
        assert dask_backend._client is None

    def test_get_client_creates_on_first_call(self, backend):
        """``_get_client`` creates a client and stores it on ``_client``."""
        assert backend._client is None
        created = backend._get_client()
        assert created is not None
        assert backend._client is created

    def test_get_client_returns_same_instance(self, backend):
        """Repeated ``_get_client`` calls return the same client object."""
        first = backend._get_client()
        second = backend._get_client()
        assert first is second

    def test_client_arg_reused(self):
        """A client passed via ``client=`` is what ``_get_client`` returns."""
        from dask.distributed import Client

        external = Client(n_workers=1, threads_per_worker=1)
        try:
            dask_backend = Dask(n_workers=1, client=external)
            assert dask_backend._get_client() is external
        finally:
            # This test created the client, so it closes it itself.
            external.close()

    def test_address_and_client_stored(self):
        """``address`` is kept in ``_address``; ``_client_arg`` stays None."""
        dask_backend = Dask(n_workers=2, address="tcp://localhost:9999")
        assert dask_backend._address == "tcp://localhost:9999"
        assert dask_backend._client_arg is None

    def test_result_ordering(self, backend):
        """Results come back in the same order as the submitted batch."""
        expected = list(range(8))
        batch = [{"x": value} for value in expected]
        assert backend._distribute(_identity_x, batch) == expected

    def test_single_item_batch(self, backend):
        """A batch of one parameter set yields a one-element result list."""
        outcome = backend._distribute(_identity_x, [{"x": 7}])
        assert outcome == [7]

    def test_empty_batch(self, backend):
        """An empty batch yields an empty result list."""
        outcome = backend._distribute(_identity_x, [])
        assert outcome == []

    def test_submit_wait_roundtrip(self, backend):
        """``_submit`` then ``_wait_any`` returns that future and its result."""
        future = backend._submit(_identity_x, {"x": 42})
        completed, result = backend._wait_any([future])
        assert result == 42
        assert completed is future

    def test_tuple_result_passthrough(self, backend):
        """``(score, metrics)`` tuples are returned unmodified."""
        outcome = backend._distribute(
            _identity_x_with_metrics, [{"x": 1}, {"x": 2}]
        )
        first, second = outcome[0], outcome[1]
        assert first == (1, {"doubled": 2})
        assert second == (2, {"doubled": 4})

0 commit comments

Comments
 (0)