|
1 | 1 | """Tests for distributed evaluation backends and search loop integration.""" |
2 | 2 |
|
3 | 3 | import concurrent.futures |
| 4 | +import os |
4 | 5 |
|
5 | 6 | import numpy as np |
6 | 7 | import pytest |
7 | 8 |
|
| 9 | +import gradient_free_optimizers._distributed._multiprocessing as _mp_module |
8 | 10 | from gradient_free_optimizers import ( |
9 | 11 | BayesianOptimizer, |
10 | 12 | DirectAlgorithm, |
|
16 | 18 | ) |
17 | 19 | from gradient_free_optimizers._distributed import ( |
18 | 20 | BaseDistribution, |
| 21 | + Dask, |
19 | 22 | Joblib, |
20 | 23 | Multiprocessing, |
| 24 | + Ray, |
21 | 25 | ) |
22 | 26 | from gradient_free_optimizers._storage import MemoryStorage |
23 | 27 |
|
@@ -666,3 +670,342 @@ def test_metric_columns_match_serial(self): |
666 | 670 | assert serial_cols == dist_cols, f"Column mismatch: {serial_cols ^ dist_cols}" |
667 | 671 | assert "loss" in dist_cols |
668 | 672 | assert "x_abs" in dist_cols |
| 673 | + |
| 674 | + |
| 675 | +def _identity_x(para): |
| 676 | + """Returns para["x"] directly. Module-level for multiprocessing pickling.""" |
| 677 | + return para["x"] |
| 678 | + |
| 679 | + |
| 680 | +def _identity_x_with_metrics(para): |
| 681 | + """Returns (score, metrics) tuple. Module-level for multiprocessing pickling.""" |
| 682 | + return para["x"], {"doubled": para["x"] * 2} |
| 683 | + |
| 684 | + |
| 685 | +class TestBaseDistributionValidation: |
| 686 | + """Unit tests for the BaseDistribution ABC contract.""" |
| 687 | + |
| 688 | + def _make_concrete(self): |
| 689 | + class Concrete(BaseDistribution): |
| 690 | + def _distribute(self, func, params_batch): |
| 691 | + return [func(p) for p in params_batch] |
| 692 | + |
| 693 | + return Concrete |
| 694 | + |
| 695 | + def test_n_workers_zero_raises(self): |
| 696 | + Concrete = self._make_concrete() |
| 697 | + with pytest.raises(ValueError, match="n_workers must be >= 1"): |
| 698 | + Concrete(n_workers=0) |
| 699 | + |
| 700 | + def test_n_workers_negative_raises(self): |
| 701 | + Concrete = self._make_concrete() |
| 702 | + with pytest.raises(ValueError, match="n_workers must be >= 1"): |
| 703 | + Concrete(n_workers=-2) |
| 704 | + |
| 705 | + def test_n_workers_one_accepted(self): |
| 706 | + Concrete = self._make_concrete() |
| 707 | + b = Concrete(n_workers=1) |
| 708 | + assert b.n_workers == 1 |
| 709 | + |
| 710 | + def test_n_workers_stored(self): |
| 711 | + Concrete = self._make_concrete() |
| 712 | + b = Concrete(n_workers=7) |
| 713 | + assert b.n_workers == 7 |
| 714 | + |
| 715 | + def test_is_async_default_false(self): |
| 716 | + assert BaseDistribution._is_async is False |
| 717 | + |
| 718 | + def test_submit_raises_not_implemented(self): |
| 719 | + Concrete = self._make_concrete() |
| 720 | + b = Concrete(n_workers=1) |
| 721 | + with pytest.raises(NotImplementedError, match="Concrete"): |
| 722 | + b._submit(lambda x: x, {}) |
| 723 | + |
| 724 | + def test_wait_any_raises_not_implemented(self): |
| 725 | + Concrete = self._make_concrete() |
| 726 | + b = Concrete(n_workers=1) |
| 727 | + with pytest.raises(NotImplementedError, match="Concrete"): |
| 728 | + b._wait_any([]) |
| 729 | + |
| 730 | + |
| 731 | +class TestDistributeWrapper: |
| 732 | + """Tests for the distribute() decorator on BaseDistribution.""" |
| 733 | + |
| 734 | + def _make_backend(self, n_workers=3): |
| 735 | + class Serial(BaseDistribution): |
| 736 | + def _distribute(self, func, params_batch): |
| 737 | + return [func(p) for p in params_batch] |
| 738 | + |
| 739 | + return Serial(n_workers=n_workers) |
| 740 | + |
| 741 | + def test_gfo_distributed_flag(self): |
| 742 | + wrapped = self._make_backend().distribute(objective) |
| 743 | + assert wrapped._gfo_distributed is True |
| 744 | + |
| 745 | + def test_gfo_batch_size_matches_n_workers(self): |
| 746 | + wrapped = self._make_backend(n_workers=5).distribute(objective) |
| 747 | + assert wrapped._gfo_batch_size == 5 |
| 748 | + |
| 749 | + def test_gfo_original_func_is_original(self): |
| 750 | + wrapped = self._make_backend().distribute(objective) |
| 751 | + assert wrapped._gfo_original_func is objective |
| 752 | + |
| 753 | + def test_gfo_backend_is_backend_instance(self): |
| 754 | + backend = self._make_backend() |
| 755 | + wrapped = backend.distribute(objective) |
| 756 | + assert wrapped._gfo_backend is backend |
| 757 | + |
| 758 | + def test_preserves_function_name(self): |
| 759 | + wrapped = self._make_backend().distribute(objective) |
| 760 | + assert wrapped.__name__ == "objective" |
| 761 | + |
| 762 | + def test_preserves_lambda_name(self): |
| 763 | + fn = lambda para: para["x"] # noqa: E731 |
| 764 | + wrapped = self._make_backend().distribute(fn) |
| 765 | + assert wrapped.__name__ == "<lambda>" |
| 766 | + |
| 767 | + def test_wrapper_delegates_to_distribute(self): |
| 768 | + wrapped = self._make_backend().distribute(_identity_x) |
| 769 | + results = wrapped([{"x": 10}, {"x": 20}]) |
| 770 | + assert results == [10, 20] |
| 771 | + |
| 772 | + def test_wrapper_is_callable(self): |
| 773 | + wrapped = self._make_backend().distribute(objective) |
| 774 | + assert callable(wrapped) |
| 775 | + |
| 776 | + |
| 777 | +class TestMultiprocessingUnit: |
| 778 | + """Unit tests for the Multiprocessing backend.""" |
| 779 | + |
| 780 | + def test_auto_detect_workers(self): |
| 781 | + mp = Multiprocessing(n_workers=-1) |
| 782 | + expected = os.cpu_count() or 1 |
| 783 | + assert mp.n_workers == expected |
| 784 | + |
| 785 | + def test_explicit_workers(self): |
| 786 | + mp = Multiprocessing(n_workers=3) |
| 787 | + assert mp.n_workers == 3 |
| 788 | + |
| 789 | + def test_prefers_fork_when_available(self): |
| 790 | + import multiprocessing |
| 791 | + |
| 792 | + mp = Multiprocessing(n_workers=2) |
| 793 | + if "fork" in multiprocessing.get_all_start_methods(): |
| 794 | + assert mp._use_fork is True |
| 795 | + else: |
| 796 | + assert mp._use_fork is False |
| 797 | + |
| 798 | + def test_context_has_valid_start_method(self): |
| 799 | + mp = Multiprocessing(n_workers=2) |
| 800 | + assert mp._mp_context.get_start_method() in ("fork", "spawn", "forkserver") |
| 801 | + |
| 802 | + def test_result_ordering(self): |
| 803 | + mp = Multiprocessing(n_workers=2) |
| 804 | + batch = [{"x": i} for i in range(10)] |
| 805 | + results = mp._distribute(_identity_x, batch) |
| 806 | + assert results == list(range(10)) |
| 807 | + |
| 808 | + def test_single_item_batch(self): |
| 809 | + mp = Multiprocessing(n_workers=2) |
| 810 | + results = mp._distribute(_identity_x, [{"x": 42}]) |
| 811 | + assert results == [42] |
| 812 | + |
| 813 | + def test_empty_batch(self): |
| 814 | + mp = Multiprocessing(n_workers=2) |
| 815 | + results = mp._distribute(_identity_x, []) |
| 816 | + assert results == [] |
| 817 | + |
| 818 | + def test_worker_func_cleaned_up_after_distribute(self): |
| 819 | + mp = Multiprocessing(n_workers=2) |
| 820 | + assert _mp_module._worker_func is None |
| 821 | + mp._distribute(_identity_x, [{"x": 1}]) |
| 822 | + assert _mp_module._worker_func is None |
| 823 | + |
| 824 | + def test_tuple_result_passthrough(self): |
| 825 | + mp = Multiprocessing(n_workers=2) |
| 826 | + batch = [{"x": 1}, {"x": 2}] |
| 827 | + results = mp._distribute(_identity_x_with_metrics, batch) |
| 828 | + assert results[0] == (1, {"doubled": 2}) |
| 829 | + assert results[1] == (2, {"doubled": 4}) |
| 830 | + |
| 831 | + |
| 832 | +class TestJoblibUnit: |
| 833 | + """Unit tests for the Joblib backend.""" |
| 834 | + |
| 835 | + def test_auto_detect_workers(self): |
| 836 | + jl = Joblib(n_workers=-1) |
| 837 | + assert jl.n_workers >= 1 |
| 838 | + |
| 839 | + def test_explicit_workers(self): |
| 840 | + jl = Joblib(n_workers=3) |
| 841 | + assert jl.n_workers == 3 |
| 842 | + |
| 843 | + def test_default_backend_is_loky(self): |
| 844 | + jl = Joblib(n_workers=2) |
| 845 | + assert jl._backend_name == "loky" |
| 846 | + |
| 847 | + def test_custom_backend_stored(self): |
| 848 | + jl = Joblib(n_workers=2, backend="threading") |
| 849 | + assert jl._backend_name == "threading" |
| 850 | + |
| 851 | + def test_result_ordering(self): |
| 852 | + jl = Joblib(n_workers=2) |
| 853 | + batch = [{"x": i} for i in range(10)] |
| 854 | + results = jl._distribute(_identity_x, batch) |
| 855 | + assert results == list(range(10)) |
| 856 | + |
| 857 | + def test_single_item_batch(self): |
| 858 | + jl = Joblib(n_workers=2) |
| 859 | + results = jl._distribute(_identity_x, [{"x": 99}]) |
| 860 | + assert results == [99] |
| 861 | + |
| 862 | + def test_empty_batch(self): |
| 863 | + jl = Joblib(n_workers=2) |
| 864 | + results = jl._distribute(_identity_x, []) |
| 865 | + assert results == [] |
| 866 | + |
| 867 | + def test_threading_backend_produces_correct_results(self): |
| 868 | + jl = Joblib(n_workers=2, backend="threading") |
| 869 | + batch = [{"x": i} for i in range(5)] |
| 870 | + results = jl._distribute(_identity_x, batch) |
| 871 | + assert results == list(range(5)) |
| 872 | + |
| 873 | + def test_tuple_result_passthrough(self): |
| 874 | + jl = Joblib(n_workers=2) |
| 875 | + batch = [{"x": 1}, {"x": 2}] |
| 876 | + results = jl._distribute(_identity_x_with_metrics, batch) |
| 877 | + assert results[0] == (1, {"doubled": 2}) |
| 878 | + assert results[1] == (2, {"doubled": 4}) |
| 879 | + |
| 880 | + |
| 881 | +class TestRayUnit: |
| 882 | + """Unit tests for the Ray backend.""" |
| 883 | + |
| 884 | + @pytest.fixture(autouse=True) |
| 885 | + def _ray_lifecycle(self): |
| 886 | + ray = pytest.importorskip("ray") |
| 887 | + ray.init(num_cpus=2, ignore_reinit_error=True) |
| 888 | + yield |
| 889 | + ray.shutdown() |
| 890 | + |
| 891 | + def test_is_async(self): |
| 892 | + assert Ray._is_async is True |
| 893 | + |
| 894 | + def test_remote_cache_initially_empty(self): |
| 895 | + r = Ray(n_workers=2) |
| 896 | + assert r._remote_cache == {} |
| 897 | + |
| 898 | + def test_remote_caches_wrapper(self): |
| 899 | + r = Ray(n_workers=2) |
| 900 | + remote1 = r._remote(_identity_x) |
| 901 | + remote2 = r._remote(_identity_x) |
| 902 | + assert remote1 is remote2 |
| 903 | + |
| 904 | + def test_remote_different_funcs_get_separate_entries(self): |
| 905 | + r = Ray(n_workers=2) |
| 906 | + r._remote(_identity_x) |
| 907 | + r._remote(_identity_x_with_metrics) |
| 908 | + assert len(r._remote_cache) == 2 |
| 909 | + |
| 910 | + def test_result_ordering(self): |
| 911 | + r = Ray(n_workers=2) |
| 912 | + batch = [{"x": i} for i in range(8)] |
| 913 | + results = r._distribute(_identity_x, batch) |
| 914 | + assert results == list(range(8)) |
| 915 | + |
| 916 | + def test_single_item_batch(self): |
| 917 | + r = Ray(n_workers=1) |
| 918 | + results = r._distribute(_identity_x, [{"x": 7}]) |
| 919 | + assert results == [7] |
| 920 | + |
| 921 | + def test_empty_batch(self): |
| 922 | + r = Ray(n_workers=1) |
| 923 | + results = r._distribute(_identity_x, []) |
| 924 | + assert results == [] |
| 925 | + |
| 926 | + def test_submit_wait_roundtrip(self): |
| 927 | + r = Ray(n_workers=2) |
| 928 | + future = r._submit(_identity_x, {"x": 42}) |
| 929 | + completed, result = r._wait_any([future]) |
| 930 | + assert result == 42 |
| 931 | + assert completed is future |
| 932 | + |
| 933 | + def test_tuple_result_passthrough(self): |
| 934 | + r = Ray(n_workers=2) |
| 935 | + batch = [{"x": 1}, {"x": 2}] |
| 936 | + results = r._distribute(_identity_x_with_metrics, batch) |
| 937 | + assert results[0] == (1, {"doubled": 2}) |
| 938 | + assert results[1] == (2, {"doubled": 4}) |
| 939 | + |
| 940 | + |
| 941 | +class TestDaskUnit: |
| 942 | + """Unit tests for the Dask backend.""" |
| 943 | + |
| 944 | + @pytest.fixture(autouse=True) |
| 945 | + def _require_dask(self): |
| 946 | + pytest.importorskip("dask.distributed") |
| 947 | + |
| 948 | + @pytest.fixture |
| 949 | + def backend(self): |
| 950 | + b = Dask(n_workers=1) |
| 951 | + yield b |
| 952 | + if b._client is not None: |
| 953 | + b._client.close() |
| 954 | + |
| 955 | + def test_is_async(self): |
| 956 | + assert Dask._is_async is True |
| 957 | + |
| 958 | + def test_client_not_created_at_init(self): |
| 959 | + b = Dask(n_workers=2) |
| 960 | + assert b._client is None |
| 961 | + |
| 962 | + def test_get_client_creates_on_first_call(self, backend): |
| 963 | + assert backend._client is None |
| 964 | + client = backend._get_client() |
| 965 | + assert client is not None |
| 966 | + assert backend._client is client |
| 967 | + |
| 968 | + def test_get_client_returns_same_instance(self, backend): |
| 969 | + c1 = backend._get_client() |
| 970 | + c2 = backend._get_client() |
| 971 | + assert c1 is c2 |
| 972 | + |
| 973 | + def test_client_arg_reused(self): |
| 974 | + from dask.distributed import Client |
| 975 | + |
| 976 | + external = Client(n_workers=1, threads_per_worker=1) |
| 977 | + try: |
| 978 | + b = Dask(n_workers=1, client=external) |
| 979 | + assert b._get_client() is external |
| 980 | + finally: |
| 981 | + external.close() |
| 982 | + |
| 983 | + def test_address_and_client_stored(self): |
| 984 | + b = Dask(n_workers=2, address="tcp://localhost:9999") |
| 985 | + assert b._address == "tcp://localhost:9999" |
| 986 | + assert b._client_arg is None |
| 987 | + |
| 988 | + def test_result_ordering(self, backend): |
| 989 | + batch = [{"x": i} for i in range(8)] |
| 990 | + results = backend._distribute(_identity_x, batch) |
| 991 | + assert results == list(range(8)) |
| 992 | + |
| 993 | + def test_single_item_batch(self, backend): |
| 994 | + results = backend._distribute(_identity_x, [{"x": 7}]) |
| 995 | + assert results == [7] |
| 996 | + |
| 997 | + def test_empty_batch(self, backend): |
| 998 | + results = backend._distribute(_identity_x, []) |
| 999 | + assert results == [] |
| 1000 | + |
| 1001 | + def test_submit_wait_roundtrip(self, backend): |
| 1002 | + future = backend._submit(_identity_x, {"x": 42}) |
| 1003 | + completed, result = backend._wait_any([future]) |
| 1004 | + assert result == 42 |
| 1005 | + assert completed is future |
| 1006 | + |
| 1007 | + def test_tuple_result_passthrough(self, backend): |
| 1008 | + batch = [{"x": 1}, {"x": 2}] |
| 1009 | + results = backend._distribute(_identity_x_with_metrics, batch) |
| 1010 | + assert results[0] == (1, {"doubled": 2}) |
| 1011 | + assert results[1] == (2, {"doubled": 4}) |
0 commit comments