Skip to content

Commit 8d71c2f

Browse files
fix: keep elementwise loss mismatch strict
Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent 1a2d050 commit 8d71c2f

2 files changed

Lines changed: 52 additions & 52 deletions

File tree

pysr/sr.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -236,45 +236,37 @@ def _check_assertions(
236236

237237

238238
def _validate_elementwise_loss(
239-
custom_loss, *, has_weights: bool, probe_value: Any = 1.0
239+
custom_loss, *, has_weights: bool, probe_dtype: Callable[[float], Any] = float
240240
) -> None:
241241
"""Validate that a Julia `elementwise_loss` is callable.
242242
243243
We probe with the dtype that the Julia backend will use, which avoids
244244
falsely rejecting strictly typed losses such as `(::Float32, ::Float32)`.
245-
If the probe still fails, emit a warning rather than raising so Julia can
246-
surface the real `MethodError` during fitting for advanced custom losses.
247245
"""
248246

249247
# This can be either a LossFunctions.jl object (e.g. `L2DistLoss()`) or a Julia function.
250248
# Only validate arity when the evaluated object is actually a function.
251249
if not jl_is_function(custom_loss):
252250
return
253251

252+
probe_value = probe_dtype(1.0)
254253
probe_args = (
255254
(probe_value, probe_value, probe_value)
256255
if has_weights
257256
else (probe_value, probe_value)
258257
)
259258
ok = bool(jl.applicable(custom_loss, *probe_args))
260-
if ok:
261-
return
262-
263259
if has_weights:
264-
warnings.warn(
265-
"`elementwise_loss` did not match the probed (prediction, target, weight) signature "
266-
"for the dtype used during fitting. Continuing anyway so Julia can surface a more "
267-
"specific `MethodError` if needed.",
268-
stacklevel=2,
269-
)
260+
if not ok:
261+
raise ValueError(
262+
"`elementwise_loss` must accept (prediction, target, weight) when `weights` is passed to `fit`."
263+
)
270264
else:
271-
warnings.warn(
272-
"`elementwise_loss` did not match the probed (prediction, target) signature for "
273-
"the dtype used during fitting. If you intended a full objective, use "
274-
"`loss_function` or `loss_function_expression`. Continuing anyway so Julia can "
275-
"surface a more specific `MethodError` if needed.",
276-
stacklevel=2,
277-
)
265+
if not ok:
266+
raise ValueError(
267+
"`elementwise_loss` must accept (prediction, target). If you intended a full objective, use "
268+
"`loss_function` or `loss_function_expression`."
269+
)
278270

279271

280272
def _validate_custom_objective(
@@ -2125,7 +2117,6 @@ def _run(
21252117
complexity_of_variables = jl_array(complexity_of_variables)
21262118

21272119
np_dtype = self._get_precision_mapped_dtype(np.array(X))
2128-
probe_value = np_dtype(1.0)
21292120

21302121
custom_loss = jl.seval(
21312122
str(self.elementwise_loss)
@@ -2136,7 +2127,7 @@ def _run(
21362127
_validate_elementwise_loss(
21372128
custom_loss,
21382129
has_weights=weights is not None,
2139-
probe_value=probe_value,
2130+
probe_dtype=np_dtype,
21402131
)
21412132

21422133
custom_full_objective = jl.seval(

pysr/test/test_main.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -258,24 +258,42 @@ def test_loss_function_varargs_objective_runs(self):
258258
y = np.array([0.0, 1.0])
259259
model.fit(X, y)
260260

261-
def test_elementwise_loss_wrong_signature_warns(self):
261+
def test_elementwise_loss_wrong_signature_errors_early(self):
262262
"""Validate `elementwise_loss` signature (prediction, target[, weights])."""
263-
custom_loss = jl.seval("myloss_bad_arity(a) = a")
264-
with self.assertWarnsRegex(UserWarning, "elementwise_loss"):
265-
_validate_elementwise_loss(
266-
custom_loss,
267-
has_weights=False,
268-
probe_value=np.float32(1.0),
269-
)
263+
model = PySRRegressor(
264+
niterations=1,
265+
populations=1,
266+
procs=0,
267+
progress=False,
268+
verbosity=0,
269+
temp_equation_file=True,
270+
binary_operators=["+"],
271+
elementwise_loss="myloss_bad_arity(a) = a",
272+
)
273+
X = np.array([[0.0], [1.0]])
274+
y = np.array([0.0, 1.0])
275+
with self.assertRaises(ValueError) as cm:
276+
model.fit(X, y)
277+
self.assertIn("elementwise_loss", str(cm.exception))
270278

271-
def test_elementwise_loss_with_weights_requires_three_args_warns(self):
272-
custom_loss = jl.seval("myloss2(prediction, target) = (prediction - target)^2")
273-
with self.assertWarnsRegex(UserWarning, "elementwise_loss"):
274-
_validate_elementwise_loss(
275-
custom_loss,
276-
has_weights=True,
277-
probe_value=np.float32(1.0),
278-
)
279+
def test_elementwise_loss_with_weights_requires_three_args(self):
280+
model = PySRRegressor(
281+
niterations=1,
282+
populations=1,
283+
procs=0,
284+
progress=False,
285+
verbosity=0,
286+
temp_equation_file=True,
287+
binary_operators=["+"],
288+
elementwise_loss="myloss2(prediction, target) = (prediction - target)^2",
289+
)
290+
X = np.array([[0.0], [1.0]])
291+
y = np.array([0.0, 1.0])
292+
weights = np.array([1.0, 1.0])
293+
with self.assertRaises(ValueError) as cm:
294+
model.fit(X, y, weights=weights)
295+
self.assertIn("elementwise_loss", str(cm.exception))
296+
self.assertIn("weights", str(cm.exception))
279297

280298
def test_elementwise_loss_with_weights_accepts_three_args(self):
281299
model = PySRRegressor(
@@ -299,14 +317,11 @@ def test_elementwise_loss_float32_probe_accepts_strictly_typed_loss(self):
299317
custom_loss = jl.seval(
300318
"(prediction::Float32, target::Float32) -> (prediction - target)^2"
301319
)
302-
with warnings.catch_warnings(record=True) as caught:
303-
warnings.simplefilter("always")
304-
_validate_elementwise_loss(
305-
custom_loss,
306-
has_weights=False,
307-
probe_value=np.float32(1.0),
308-
)
309-
self.assertEqual(len(caught), 0)
320+
_validate_elementwise_loss(
321+
custom_loss,
322+
has_weights=False,
323+
probe_dtype=np.float32,
324+
)
310325

311326
def test_elementwise_loss_float32_fit_accepts_strictly_typed_loss(self):
312327
model = PySRRegressor(
@@ -324,13 +339,7 @@ def test_elementwise_loss_float32_fit_accepts_strictly_typed_loss(self):
324339
)
325340
X = np.array([[0.0], [1.0]], dtype=np.float32)
326341
y = np.array([0.0, 1.0], dtype=np.float32)
327-
with warnings.catch_warnings(record=True) as caught:
328-
warnings.simplefilter("always")
329-
model.fit(X, y)
330-
self.assertFalse(
331-
any("elementwise_loss" in str(w.message) for w in caught),
332-
msg=[str(w.message) for w in caught],
333-
)
342+
model.fit(X, y)
334343

335344
def test_validation_helpers_skip_nonfunction(self):
336345
_validate_elementwise_loss(jl.seval("1.0"), has_weights=False)

0 commit comments

Comments
 (0)