|
1 | 1 | """Tests for input_stream default behavior.""" |
2 | 2 |
|
3 | | -import unittest |
4 | | -from unittest import mock |
5 | | - |
6 | | -import numpy as np |
7 | | - |
8 | 3 | from pysr import PySRRegressor |
9 | | -from pysr.sr import ALREADY_RAN |
10 | | - |
11 | | - |
12 | | -class TestInputStream(unittest.TestCase): |
13 | | - """Verify input_stream defaults and backend passthrough.""" |
14 | | - |
15 | | - def test_default_is_none(self): |
16 | | - """By default, input_stream should be None so the Julia backend picks the stream.""" |
17 | | - model = PySRRegressor() |
18 | | - self.assertIsNone(model.input_stream) |
19 | | - |
20 | | - def test_explicit_stdin(self): |
21 | | - model = PySRRegressor(input_stream="stdin") |
22 | | - self.assertEqual(model.input_stream, "stdin") |
23 | | - |
24 | | - def test_explicit_devnull(self): |
25 | | - model = PySRRegressor(input_stream="devnull") |
26 | | - self.assertEqual(model.input_stream, "devnull") |
27 | | - |
28 | | - def _make_mock_jl(self): |
29 | | - """Return a MagicMock that satisfies the Julia calls made before Options().""" |
30 | | - m = mock.MagicMock() |
31 | | - m.seval.return_value = mock.MagicMock() |
32 | | - m.Dict.return_value = mock.MagicMock() |
33 | | - m.Pair.return_value = mock.MagicMock() |
34 | | - m.Symbol.return_value = mock.MagicMock() |
35 | | - m.NamedTuple.return_value = mock.MagicMock() |
36 | | - return m |
37 | | - |
38 | | - def _mocked_fit(self, model, X, y, *, capture_options): |
39 | | - """Run model.fit with enough Julia infrastructure mocked to reach Options().""" |
40 | | - mock_jl = self._make_mock_jl() |
41 | | - sr = mock.MagicMock() |
42 | | - sr.MutationWeights.return_value = mock.MagicMock() |
43 | | - sr.Options.side_effect = capture_options |
44 | | - sr.equation_search.return_value = (mock.MagicMock(), b"mock") |
45 | | - sr.SearchUtilsModule.generate_run_id.return_value = "test-run-id" |
46 | | - |
47 | | - # Reset global so _run doesn't skip its first-run block. |
48 | | - import pysr.sr as sr_module |
49 | | - |
50 | | - old_already_ran = sr_module.ALREADY_RAN |
51 | | - sr_module.ALREADY_RAN = False |
52 | | - |
53 | | - try: |
54 | | - with mock.patch("pysr.sr.jl", mock_jl): |
55 | | - with mock.patch("pysr.sr.SymbolicRegression", sr): |
56 | | - with mock.patch("pysr.sr.jl_array", return_value=mock.MagicMock()): |
57 | | - with mock.patch("pysr.sr.jl_is_function", return_value=True): |
58 | | - with mock.patch( |
59 | | - "pysr.sr.jl_serialize", return_value=b"mock" |
60 | | - ): |
61 | | - with mock.patch("pysr.sr.load_required_packages"): |
62 | | - with mock.patch("pysr.sr._load_cluster_manager"): |
63 | | - model.fit(X, y) |
64 | | - finally: |
65 | | - sr_module.ALREADY_RAN = old_already_ran |
66 | | - |
67 | | - def test_default_omits_input_stream_from_options(self): |
68 | | - """When input_stream is None, Options should not receive the kwarg.""" |
69 | | - captured_kwargs = {} |
70 | | - |
71 | | - def capture_options(**kwargs): |
72 | | - captured_kwargs.update(kwargs) |
73 | | - raise RuntimeError("stop_after_options") |
74 | | - |
75 | | - X = np.array([[1.0, 2.0], [3.0, 4.0]]) |
76 | | - y = np.array([1.0, 2.0]) |
77 | | - |
78 | | - model = PySRRegressor( |
79 | | - niterations=0, |
80 | | - max_evals=0, |
81 | | - populations=1, |
82 | | - ncycles_per_iteration=0, |
83 | | - progress=False, |
84 | | - verbosity=0, |
85 | | - temp_equation_file=True, |
86 | | - parallelism="serial", |
87 | | - ) |
88 | | - |
89 | | - with self.assertRaises(RuntimeError) as cm: |
90 | | - self._mocked_fit(model, X, y, capture_options=capture_options) |
91 | | - self.assertEqual(str(cm.exception), "stop_after_options") |
92 | | - self.assertNotIn("input_stream", captured_kwargs) |
93 | | - |
94 | | - def test_explicit_stdin_passes_input_stream(self): |
95 | | - """When input_stream is 'stdin', Options should receive the kwarg.""" |
96 | | - captured_kwargs = {} |
97 | | - |
98 | | - def capture_options(**kwargs): |
99 | | - captured_kwargs.update(kwargs) |
100 | | - raise RuntimeError("stop_after_options") |
101 | | - |
102 | | - X = np.array([[1.0, 2.0], [3.0, 4.0]]) |
103 | | - y = np.array([1.0, 2.0]) |
104 | | - |
105 | | - model = PySRRegressor( |
106 | | - input_stream="stdin", |
107 | | - niterations=0, |
108 | | - max_evals=0, |
109 | | - populations=1, |
110 | | - ncycles_per_iteration=0, |
111 | | - progress=False, |
112 | | - verbosity=0, |
113 | | - temp_equation_file=True, |
114 | | - parallelism="serial", |
115 | | - ) |
116 | | - |
117 | | - with self.assertRaises(RuntimeError) as cm: |
118 | | - self._mocked_fit(model, X, y, capture_options=capture_options) |
119 | | - self.assertEqual(str(cm.exception), "stop_after_options") |
120 | | - self.assertIn("input_stream", captured_kwargs) |
121 | | - |
122 | | - def test_explicit_devnull_passes_input_stream(self): |
123 | | - """When input_stream is 'devnull', Options should receive the kwarg.""" |
124 | | - captured_kwargs = {} |
125 | | - |
126 | | - def capture_options(**kwargs): |
127 | | - captured_kwargs.update(kwargs) |
128 | | - raise RuntimeError("stop_after_options") |
129 | | - |
130 | | - X = np.array([[1.0, 2.0], [3.0, 4.0]]) |
131 | | - y = np.array([1.0, 2.0]) |
132 | 4 |
|
133 | | - model = PySRRegressor( |
134 | | - input_stream="devnull", |
135 | | - niterations=0, |
136 | | - max_evals=0, |
137 | | - populations=1, |
138 | | - ncycles_per_iteration=0, |
139 | | - progress=False, |
140 | | - verbosity=0, |
141 | | - temp_equation_file=True, |
142 | | - parallelism="serial", |
143 | | - ) |
144 | 5 |
|
145 | | - with self.assertRaises(RuntimeError) as cm: |
146 | | - self._mocked_fit(model, X, y, capture_options=capture_options) |
147 | | - self.assertEqual(str(cm.exception), "stop_after_options") |
148 | | - self.assertIn("input_stream", captured_kwargs) |
| 6 | +def test_default_is_none(): |
| 7 | + assert PySRRegressor().input_stream is None |
149 | 8 |
|
150 | 9 |
|
151 | | -def runtests(just_tests=False): |
152 | | - tests = [TestInputStream] |
153 | | - if just_tests: |
154 | | - return tests |
155 | | - suite = unittest.TestSuite() |
156 | | - loader = unittest.TestLoader() |
157 | | - for test in tests: |
158 | | - suite.addTests(loader.loadTestsFromTestCase(test)) |
159 | | - runner = unittest.TextTestRunner() |
160 | | - return runner.run(suite) |
| 10 | +def test_explicit_values(): |
| 11 | + assert PySRRegressor(input_stream="stdin").input_stream == "stdin" |
| 12 | + assert PySRRegressor(input_stream="devnull").input_stream == "devnull" |
0 commit comments