Skip to content

Commit 8105905

Browse files
pre-commit-ci[bot]MilesCranmerBot
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 08851dd commit 8105905

2 files changed

Lines changed: 14 additions & 160 deletions

File tree

pysr/sr.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,9 +1551,9 @@ def get_best(
15511551

15521552
if index is not None:
15531553
if isinstance(self.equations_, list):
1554-
assert isinstance(index, list), (
1555-
"With multiple output features, index must be a list."
1556-
)
1554+
assert isinstance(
1555+
index, list
1556+
), "With multiple output features, index must be a list."
15571557
return [eq.iloc[i] for eq, i in zip(self.equations_, index)]
15581558
else:
15591559
equations_ = cast(pd.DataFrame, self.equations_)
@@ -2152,8 +2152,11 @@ def _run(
21522152
else "nothing"
21532153
)
21542154

2155-
if self.input_stream is not None:
2156-
input_stream = jl.seval(self.input_stream)
2155+
input_stream = (
2156+
jl.seval(self.input_stream)
2157+
if self.input_stream is not None
2158+
else jl.seval("nothing")
2159+
)
21572160

21582161
load_required_packages(
21592162
turbo=self.turbo,
@@ -2310,13 +2313,12 @@ def _run(
23102313
crossover_probability=self.crossover_probability,
23112314
skip_mutation_failures=self.skip_mutation_failures,
23122315
max_evals=self.max_evals,
2316+
input_stream=input_stream,
23132317
early_stop_condition=early_stop_condition,
23142318
seed=seed,
23152319
deterministic=self.deterministic,
23162320
define_helper_functions=False,
23172321
)
2318-
if self.input_stream is not None:
2319-
options_kwargs["input_stream"] = input_stream
23202322
options = SymbolicRegression.Options(**options_kwargs)
23212323

23222324
self.julia_options_stream_ = jl_serialize(options)

pysr/test/test_input_stream.py

Lines changed: 5 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -1,160 +1,12 @@
11
"""Tests for input_stream default behavior."""
22

3-
import unittest
4-
from unittest import mock
5-
6-
import numpy as np
7-
83
from pysr import PySRRegressor
9-
from pysr.sr import ALREADY_RAN
10-
11-
12-
class TestInputStream(unittest.TestCase):
13-
"""Verify input_stream defaults and backend passthrough."""
14-
15-
def test_default_is_none(self):
16-
"""By default, input_stream should be None so the Julia backend picks the stream."""
17-
model = PySRRegressor()
18-
self.assertIsNone(model.input_stream)
19-
20-
def test_explicit_stdin(self):
21-
model = PySRRegressor(input_stream="stdin")
22-
self.assertEqual(model.input_stream, "stdin")
23-
24-
def test_explicit_devnull(self):
25-
model = PySRRegressor(input_stream="devnull")
26-
self.assertEqual(model.input_stream, "devnull")
27-
28-
def _make_mock_jl(self):
29-
"""Return a MagicMock that satisfies the Julia calls made before Options()."""
30-
m = mock.MagicMock()
31-
m.seval.return_value = mock.MagicMock()
32-
m.Dict.return_value = mock.MagicMock()
33-
m.Pair.return_value = mock.MagicMock()
34-
m.Symbol.return_value = mock.MagicMock()
35-
m.NamedTuple.return_value = mock.MagicMock()
36-
return m
37-
38-
def _mocked_fit(self, model, X, y, *, capture_options):
39-
"""Run model.fit with enough Julia infrastructure mocked to reach Options()."""
40-
mock_jl = self._make_mock_jl()
41-
sr = mock.MagicMock()
42-
sr.MutationWeights.return_value = mock.MagicMock()
43-
sr.Options.side_effect = capture_options
44-
sr.equation_search.return_value = (mock.MagicMock(), b"mock")
45-
sr.SearchUtilsModule.generate_run_id.return_value = "test-run-id"
46-
47-
# Reset global so _run doesn't skip its first-run block.
48-
import pysr.sr as sr_module
49-
50-
old_already_ran = sr_module.ALREADY_RAN
51-
sr_module.ALREADY_RAN = False
52-
53-
try:
54-
with mock.patch("pysr.sr.jl", mock_jl):
55-
with mock.patch("pysr.sr.SymbolicRegression", sr):
56-
with mock.patch("pysr.sr.jl_array", return_value=mock.MagicMock()):
57-
with mock.patch("pysr.sr.jl_is_function", return_value=True):
58-
with mock.patch(
59-
"pysr.sr.jl_serialize", return_value=b"mock"
60-
):
61-
with mock.patch("pysr.sr.load_required_packages"):
62-
with mock.patch("pysr.sr._load_cluster_manager"):
63-
model.fit(X, y)
64-
finally:
65-
sr_module.ALREADY_RAN = old_already_ran
66-
67-
def test_default_omits_input_stream_from_options(self):
68-
"""When input_stream is None, Options should not receive the kwarg."""
69-
captured_kwargs = {}
70-
71-
def capture_options(**kwargs):
72-
captured_kwargs.update(kwargs)
73-
raise RuntimeError("stop_after_options")
74-
75-
X = np.array([[1.0, 2.0], [3.0, 4.0]])
76-
y = np.array([1.0, 2.0])
77-
78-
model = PySRRegressor(
79-
niterations=0,
80-
max_evals=0,
81-
populations=1,
82-
ncycles_per_iteration=0,
83-
progress=False,
84-
verbosity=0,
85-
temp_equation_file=True,
86-
parallelism="serial",
87-
)
88-
89-
with self.assertRaises(RuntimeError) as cm:
90-
self._mocked_fit(model, X, y, capture_options=capture_options)
91-
self.assertEqual(str(cm.exception), "stop_after_options")
92-
self.assertNotIn("input_stream", captured_kwargs)
93-
94-
def test_explicit_stdin_passes_input_stream(self):
95-
"""When input_stream is 'stdin', Options should receive the kwarg."""
96-
captured_kwargs = {}
97-
98-
def capture_options(**kwargs):
99-
captured_kwargs.update(kwargs)
100-
raise RuntimeError("stop_after_options")
101-
102-
X = np.array([[1.0, 2.0], [3.0, 4.0]])
103-
y = np.array([1.0, 2.0])
104-
105-
model = PySRRegressor(
106-
input_stream="stdin",
107-
niterations=0,
108-
max_evals=0,
109-
populations=1,
110-
ncycles_per_iteration=0,
111-
progress=False,
112-
verbosity=0,
113-
temp_equation_file=True,
114-
parallelism="serial",
115-
)
116-
117-
with self.assertRaises(RuntimeError) as cm:
118-
self._mocked_fit(model, X, y, capture_options=capture_options)
119-
self.assertEqual(str(cm.exception), "stop_after_options")
120-
self.assertIn("input_stream", captured_kwargs)
121-
122-
def test_explicit_devnull_passes_input_stream(self):
123-
"""When input_stream is 'devnull', Options should receive the kwarg."""
124-
captured_kwargs = {}
125-
126-
def capture_options(**kwargs):
127-
captured_kwargs.update(kwargs)
128-
raise RuntimeError("stop_after_options")
129-
130-
X = np.array([[1.0, 2.0], [3.0, 4.0]])
131-
y = np.array([1.0, 2.0])
1324

133-
model = PySRRegressor(
134-
input_stream="devnull",
135-
niterations=0,
136-
max_evals=0,
137-
populations=1,
138-
ncycles_per_iteration=0,
139-
progress=False,
140-
verbosity=0,
141-
temp_equation_file=True,
142-
parallelism="serial",
143-
)
1445

145-
with self.assertRaises(RuntimeError) as cm:
146-
self._mocked_fit(model, X, y, capture_options=capture_options)
147-
self.assertEqual(str(cm.exception), "stop_after_options")
148-
self.assertIn("input_stream", captured_kwargs)
6+
def test_default_is_none():
7+
assert PySRRegressor().input_stream is None
1498

1509

151-
def runtests(just_tests=False):
152-
tests = [TestInputStream]
153-
if just_tests:
154-
return tests
155-
suite = unittest.TestSuite()
156-
loader = unittest.TestLoader()
157-
for test in tests:
158-
suite.addTests(loader.loadTestsFromTestCase(test))
159-
runner = unittest.TextTestRunner()
160-
return runner.run(suite)
10+
def test_explicit_values():
11+
assert PySRRegressor(input_stream="stdin").input_stream == "stdin"
12+
assert PySRRegressor(input_stream="devnull").input_stream == "devnull"

0 commit comments

Comments
 (0)