Skip to content

Commit 775ab95

Browse files
fix: simplify elementwise loss probe dtype
Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent ac4fd6e commit 775ab95

2 files changed

Lines changed: 6 additions & 6 deletions

File tree

pysr/sr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def _check_assertions(
236236

237237

238238
def _validate_elementwise_loss(
239-
custom_loss, *, has_weights: bool, probe_value: Any = 1.0
239+
custom_loss, *, has_weights: bool, probe_dtype: Callable[[float], Any] = float
240240
) -> None:
241241
"""Validate that a Julia `elementwise_loss` is callable.
242242
@@ -251,6 +251,7 @@ def _validate_elementwise_loss(
251251
if not jl_is_function(custom_loss):
252252
return
253253

254+
probe_value = probe_dtype(1.0)
254255
probe_args = (
255256
(probe_value, probe_value, probe_value)
256257
if has_weights
@@ -2125,7 +2126,6 @@ def _run(
21252126
complexity_of_variables = jl_array(complexity_of_variables)
21262127

21272128
np_dtype = self._get_precision_mapped_dtype(np.array(X))
2128-
probe_value = np_dtype(1.0)
21292129

21302130
custom_loss = jl.seval(
21312131
str(self.elementwise_loss)
@@ -2136,7 +2136,7 @@ def _run(
21362136
_validate_elementwise_loss(
21372137
custom_loss,
21382138
has_weights=weights is not None,
2139-
probe_value=probe_value,
2139+
probe_dtype=np_dtype,
21402140
)
21412141

21422142
custom_full_objective = jl.seval(

pysr/test/test_main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def test_elementwise_loss_wrong_signature_warns(self):
265265
_validate_elementwise_loss(
266266
custom_loss,
267267
has_weights=False,
268-
probe_value=np.float32(1.0),
268+
probe_dtype=np.float32,
269269
)
270270

271271
def test_elementwise_loss_with_weights_requires_three_args_warns(self):
@@ -274,7 +274,7 @@ def test_elementwise_loss_with_weights_requires_three_args_warns(self):
274274
_validate_elementwise_loss(
275275
custom_loss,
276276
has_weights=True,
277-
probe_value=np.float32(1.0),
277+
probe_dtype=np.float32,
278278
)
279279

280280
def test_elementwise_loss_with_weights_accepts_three_args(self):
@@ -304,7 +304,7 @@ def test_elementwise_loss_float32_probe_accepts_strictly_typed_loss(self):
304304
_validate_elementwise_loss(
305305
custom_loss,
306306
has_weights=False,
307-
probe_value=np.float32(1.0),
307+
probe_dtype=np.float32,
308308
)
309309
self.assertEqual(len(caught), 0)
310310

0 commit comments

Comments
 (0)