Skip to content

Commit 34e5594

Browse files
committed
Update tests
1 parent 027ceba commit 34e5594

1 file changed

Lines changed: 18 additions & 73 deletions

File tree

tests/unit/test_tuner.py

Lines changed: 18 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Provides unit tests for the Tuner class."""
22

3+
from tempfile import TemporaryDirectory
4+
import typing as _t
35
from unittest.mock import MagicMock, patch
46

57
import msgspec
68
import pytest
9+
import ray.tune
710

811
from plugboard.schemas import ConfigSpec, ObjectiveSpec
912
from plugboard.schemas.tune import (
@@ -23,6 +26,13 @@ def config() -> dict:
2326
return msgspec.yaml.decode(f.read())
2427

2528

29+
@pytest.fixture
30+
def temp_dir() -> _t.Iterator[str]:
31+
"""Creates a temporary directory."""
32+
with TemporaryDirectory() as tmpdir:
33+
yield tmpdir
34+
35+
2636
@patch("ray.tune.Tuner")
2737
def test_tuner(mock_tuner_cls: MagicMock, config: dict) -> None:
2838
"""Test the Tuner class."""
@@ -91,64 +101,7 @@ def test_tuner(mock_tuner_cls: MagicMock, config: dict) -> None:
91101
mock_tuner.fit.assert_called_once()
92102

93103

94-
@patch("ray.tune.Tuner")
95-
def test_tuner_with_optuna_storage(mock_tuner_cls: MagicMock, config: dict) -> None:
96-
"""Test the Tuner class with Optuna storage URI."""
97-
mock_tuner = MagicMock()
98-
mock_tuner_cls.return_value = mock_tuner
99-
100-
spec = ConfigSpec.model_validate(config)
101-
process_spec = spec.plugboard.process
102-
103-
# Test with storage URI
104-
optuna_spec = OptunaSpec(
105-
type="ray.tune.search.optuna.OptunaSearch",
106-
study_name="test-study",
107-
storage="sqlite:///test.db",
108-
)
109-
110-
tuner = Tuner(
111-
objective=ObjectiveSpec(
112-
object_type="component",
113-
object_name="c",
114-
field_type="field",
115-
field_name="in_1",
116-
),
117-
parameters=[
118-
FloatParameterSpec(
119-
object_type="component",
120-
object_name="a",
121-
field_type="arg",
122-
field_name="y",
123-
lower=0.1,
124-
upper=0.5,
125-
),
126-
],
127-
num_samples=6,
128-
mode="max",
129-
max_concurrent=2,
130-
algorithm=optuna_spec,
131-
)
132-
tuner.run(spec=process_spec)
133-
134-
# Must call the Tuner class with objective
135-
assert callable(mock_tuner_cls.call_args.args[0])
136-
# Must call the Tuner class with parameter space
137-
kwargs = mock_tuner_cls.call_args.kwargs
138-
param_space = kwargs["param_space"]
139-
assert param_space["a.y"].__class__.__name__ == "Float"
140-
assert param_space["a.y"].lower == 0.1
141-
assert param_space["a.y"].upper == 0.5
142-
# Must call the Tuner class with configuration and correct algorithm
143-
tune_config = kwargs["tune_config"]
144-
assert tune_config.num_samples == 6
145-
# Check searcher attribute as this contains the underlying algorithm with storage converted
146-
assert tune_config.search_alg.searcher.__class__.__name__ == "OptunaSearch"
147-
# Must call fit method on the Tuner object
148-
mock_tuner.fit.assert_called_once()
149-
150-
151-
def test_optuna_storage_uri_conversion() -> None:
104+
def test_optuna_storage_uri_conversion(temp_dir: str) -> None:
152105
"""Test that storage URI gets converted to Optuna storage object."""
153106
# Create a tuner with minimal configuration
154107
tuner = Tuner(
@@ -167,19 +120,11 @@ def test_optuna_storage_uri_conversion() -> None:
167120
],
168121
num_samples=1,
169122
mode="max",
123+
algorithm=OptunaSpec(
124+
type="ray.tune.search.optuna.OptunaSearch",
125+
study_name="test-study",
126+
storage=f"sqlite:///{temp_dir}/test_conversion.db",
127+
),
170128
)
171-
172-
# Test the _build_algorithm method with storage URI
173-
optuna_spec = OptunaSpec(
174-
type="ray.tune.search.optuna.OptunaSearch",
175-
study_name="test-study",
176-
storage="sqlite:///test_conversion.db",
177-
)
178-
179-
# This should work without raising AssertionError
180-
algorithm = tuner._build_algorithm(optuna_spec)
181-
182-
# Verify the algorithm was created successfully
183-
import ray.tune.search.optuna
184-
185-
assert isinstance(algorithm, ray.tune.search.optuna.OptunaSearch)
129+
algo = tuner._config.search_alg
130+
assert isinstance(algo, ray.tune.search.optuna.OptunaSearch)

0 commit comments

Comments
 (0)