Skip to content

Commit 1bf44c1

Browse files
fix: convert num_features dict keys to Julia Symbols
Closes MilesCranmer#811 Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent 9283914 commit 1bf44c1

3 files changed

Lines changed: 27 additions & 6 deletions

File tree

pysr/expression_specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def julia_expression_options(self):
293293
num_features = if num_features === nothing
294294
nothing
295295
else
296-
(; num_features...)
296+
NamedTuple(Symbol(k) => v for (k, v) in num_features)
297297
end
298298
structure = SymbolicRegression.TemplateStructure{tuple_symbol}(combine, num_features)
299299
return (; structure)

pysr/sr.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,15 +1590,16 @@ def _setup_equation_file(self):
15901590
if self.output_directory is not None:
15911591
assert self.output_directory_ == self.output_directory
15921592
else:
1593-
self.output_directory_ = (
1594-
tempfile.mkdtemp()
1595-
if self.temp_equation_file
1596-
else (
1593+
if self.temp_equation_file:
1594+
if self.tempdir is not None:
1595+
Path(self.tempdir).mkdir(parents=True, exist_ok=True)
1596+
self.output_directory_ = tempfile.mkdtemp(dir=self.tempdir)
1597+
else:
1598+
self.output_directory_ = (
15971599
"outputs"
15981600
if self.output_directory is None
15991601
else self.output_directory
16001602
)
1601-
)
16021603
self.run_id_ = (
16031604
cast(str, SymbolicRegression.SearchUtilsModule.generate_run_id())
16041605
if self.run_id is None

pysr/test/test_main.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ def setUp(self):
7777
self.rstate = np.random.RandomState(0)
7878
self.X = self.rstate.randn(100, 5)
7979

80+
def test_temp_equation_file_respects_tempdir(self):
81+
with tempfile.TemporaryDirectory() as d:
82+
tempdir = Path(d) / "pysr-temp"
83+
model = PySRRegressor(
84+
temp_equation_file=True, tempdir=str(tempdir), run_id="t"
85+
)
86+
model._setup_equation_file()
87+
self.assertEqual(Path(model.output_directory_).parent, tempdir)
88+
8089
def test_linear_relation(self):
8190
y = self.X[:, 0]
8291
model = PySRRegressor(
@@ -2237,6 +2246,17 @@ def test_process_constraints_swaps_multiplication_constraints(self):
22372246

22382247

22392248
class TestTemplateExpressionSpec(unittest.TestCase):
2249+
def test_num_features_symbol_keys(self):
2250+
# ponytail: one check — dict keys must reach Julia as Symbols
2251+
spec = TemplateExpressionSpec(
2252+
["f", "g"],
2253+
"combine(fs, vars) = fs.f(vars[1], vars[2]) + fs.g(vars[3])",
2254+
{"f": 2, "g": 1},
2255+
)
2256+
options = spec.julia_expression_options()
2257+
names = jl.seval("x -> propertynames(x.structure.num_features)")(options)
2258+
self.assertEqual(names, (jl.Symbol("f"), jl.Symbol("g")))
2259+
22402260
def _check_macro_str(self, spec, expected_str):
22412261
self.assertEqual(
22422262
spec._template_macro_str().strip(), dedent(expected_str).strip()

0 commit comments

Comments
 (0)