11"""Provides unit tests for the Tuner class."""
22
3+ from tempfile import TemporaryDirectory
4+ import typing as _t
35from unittest .mock import MagicMock , patch
46
57import msgspec
68import pytest
9+ import ray .tune
710
811from plugboard .schemas import ConfigSpec , ObjectiveSpec
912from plugboard .schemas .tune import (
@@ -23,6 +26,13 @@ def config() -> dict:
2326 return msgspec .yaml .decode (f .read ())
2427
2528
29+ @pytest .fixture
30+ def temp_dir () -> _t .Iterator [str ]:
31+ """Creates a temporary directory."""
32+ with TemporaryDirectory () as tmpdir :
33+ yield tmpdir
34+
35+
2636@patch ("ray.tune.Tuner" )
2737def test_tuner (mock_tuner_cls : MagicMock , config : dict ) -> None :
2838 """Test the Tuner class."""
@@ -91,64 +101,7 @@ def test_tuner(mock_tuner_cls: MagicMock, config: dict) -> None:
91101 mock_tuner .fit .assert_called_once ()
92102
93103
94- @patch ("ray.tune.Tuner" )
95- def test_tuner_with_optuna_storage (mock_tuner_cls : MagicMock , config : dict ) -> None :
96- """Test the Tuner class with Optuna storage URI."""
97- mock_tuner = MagicMock ()
98- mock_tuner_cls .return_value = mock_tuner
99-
100- spec = ConfigSpec .model_validate (config )
101- process_spec = spec .plugboard .process
102-
103- # Test with storage URI
104- optuna_spec = OptunaSpec (
105- type = "ray.tune.search.optuna.OptunaSearch" ,
106- study_name = "test-study" ,
107- storage = "sqlite:///test.db" ,
108- )
109-
110- tuner = Tuner (
111- objective = ObjectiveSpec (
112- object_type = "component" ,
113- object_name = "c" ,
114- field_type = "field" ,
115- field_name = "in_1" ,
116- ),
117- parameters = [
118- FloatParameterSpec (
119- object_type = "component" ,
120- object_name = "a" ,
121- field_type = "arg" ,
122- field_name = "y" ,
123- lower = 0.1 ,
124- upper = 0.5 ,
125- ),
126- ],
127- num_samples = 6 ,
128- mode = "max" ,
129- max_concurrent = 2 ,
130- algorithm = optuna_spec ,
131- )
132- tuner .run (spec = process_spec )
133-
134- # Must call the Tuner class with objective
135- assert callable (mock_tuner_cls .call_args .args [0 ])
136- # Must call the Tuner class with parameter space
137- kwargs = mock_tuner_cls .call_args .kwargs
138- param_space = kwargs ["param_space" ]
139- assert param_space ["a.y" ].__class__ .__name__ == "Float"
140- assert param_space ["a.y" ].lower == 0.1
141- assert param_space ["a.y" ].upper == 0.5
142- # Must call the Tuner class with configuration and correct algorithm
143- tune_config = kwargs ["tune_config" ]
144- assert tune_config .num_samples == 6
145- # Check searcher attribute as this contains the underlying algorithm with storage converted
146- assert tune_config .search_alg .searcher .__class__ .__name__ == "OptunaSearch"
147- # Must call fit method on the Tuner object
148- mock_tuner .fit .assert_called_once ()
149-
150-
151- def test_optuna_storage_uri_conversion () -> None :
104+ def test_optuna_storage_uri_conversion (temp_dir : str ) -> None :
152105 """Test that storage URI gets converted to Optuna storage object."""
153106 # Create a tuner with minimal configuration
154107 tuner = Tuner (
@@ -167,19 +120,11 @@ def test_optuna_storage_uri_conversion() -> None:
167120 ],
168121 num_samples = 1 ,
169122 mode = "max" ,
123+ algorithm = OptunaSpec (
124+ type = "ray.tune.search.optuna.OptunaSearch" ,
125+ study_name = "test-study" ,
126+ storage = f"sqlite:///{ temp_dir } /test_conversion.db" ,
127+ ),
170128 )
171-
172- # Test the _build_algorithm method with storage URI
173- optuna_spec = OptunaSpec (
174- type = "ray.tune.search.optuna.OptunaSearch" ,
175- study_name = "test-study" ,
176- storage = "sqlite:///test_conversion.db" ,
177- )
178-
179- # This should work without raising AssertionError
180- algorithm = tuner ._build_algorithm (optuna_spec )
181-
182- # Verify the algorithm was created successfully
183- import ray .tune .search .optuna
184-
185- assert isinstance (algorithm , ray .tune .search .optuna .OptunaSearch )
129+ algo = tuner ._config .search_alg
130+ assert isinstance (algo , ray .tune .search .optuna .OptunaSearch )
0 commit comments