Skip to content

Commit 1a2d050

Browse files
Revert "fix: simplify elementwise loss probe dtype"
This reverts commit 775ab95.
1 parent 775ab95 commit 1a2d050

2 files changed

Lines changed: 6 additions & 6 deletions

File tree

pysr/sr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def _check_assertions(
236236

237237

238238
def _validate_elementwise_loss(
239-
custom_loss, *, has_weights: bool, probe_dtype: Callable[[float], Any] = float
239+
custom_loss, *, has_weights: bool, probe_value: Any = 1.0
240240
) -> None:
241241
"""Validate that a Julia `elementwise_loss` is callable.
242242
@@ -251,7 +251,6 @@ def _validate_elementwise_loss(
251251
if not jl_is_function(custom_loss):
252252
return
253253

254-
probe_value = probe_dtype(1.0)
255254
probe_args = (
256255
(probe_value, probe_value, probe_value)
257256
if has_weights
@@ -2126,6 +2125,7 @@ def _run(
21262125
complexity_of_variables = jl_array(complexity_of_variables)
21272126

21282127
np_dtype = self._get_precision_mapped_dtype(np.array(X))
2128+
probe_value = np_dtype(1.0)
21292129

21302130
custom_loss = jl.seval(
21312131
str(self.elementwise_loss)
@@ -2136,7 +2136,7 @@ def _run(
21362136
_validate_elementwise_loss(
21372137
custom_loss,
21382138
has_weights=weights is not None,
2139-
probe_dtype=np_dtype,
2139+
probe_value=probe_value,
21402140
)
21412141

21422142
custom_full_objective = jl.seval(

pysr/test/test_main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def test_elementwise_loss_wrong_signature_warns(self):
265265
_validate_elementwise_loss(
266266
custom_loss,
267267
has_weights=False,
268-
probe_dtype=np.float32,
268+
probe_value=np.float32(1.0),
269269
)
270270

271271
def test_elementwise_loss_with_weights_requires_three_args_warns(self):
@@ -274,7 +274,7 @@ def test_elementwise_loss_with_weights_requires_three_args_warns(self):
274274
_validate_elementwise_loss(
275275
custom_loss,
276276
has_weights=True,
277-
probe_dtype=np.float32,
277+
probe_value=np.float32(1.0),
278278
)
279279

280280
def test_elementwise_loss_with_weights_accepts_three_args(self):
@@ -304,7 +304,7 @@ def test_elementwise_loss_float32_probe_accepts_strictly_typed_loss(self):
304304
_validate_elementwise_loss(
305305
custom_loss,
306306
has_weights=False,
307-
probe_dtype=np.float32,
307+
probe_value=np.float32(1.0),
308308
)
309309
self.assertEqual(len(caught), 0)
310310

0 commit comments

Comments
 (0)