Skip to content

Commit 289f1ee

Browse files
fix: warn on elementwise loss probe mismatch
Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent c435527 commit 289f1ee

2 files changed

Lines changed: 85 additions & 44 deletions

File tree

pysr/sr.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -235,31 +235,46 @@ def _check_assertions(
235235
)
236236

237237

238-
def _validate_elementwise_loss(custom_loss, *, has_weights: bool) -> None:
238+
def _validate_elementwise_loss(
239+
custom_loss, *, has_weights: bool, probe_value: Any = 1.0
240+
) -> None:
239241
"""Validate that a Julia `elementwise_loss` is callable.
240242
241-
We require exactly 2 args unless the user passed `weights=` to fit,
242-
in which case we require 3 args.
243+
We probe with the dtype that the Julia backend will use, which avoids
244+
falsely rejecting strictly typed losses such as `(::Float32, ::Float32)`.
245+
If the probe still fails, emit a warning rather than raising so Julia can
246+
surface the real `MethodError` during fitting for advanced custom losses.
243247
"""
244248

245249
# This can be either a LossFunctions.jl object (e.g. `L2DistLoss()`) or a Julia function.
246250
# Only validate arity when the evaluated object is actually a function.
247251
if not jl_is_function(custom_loss):
248252
return
249253

254+
probe_args = (
255+
(probe_value, probe_value, probe_value)
256+
if has_weights
257+
else (probe_value, probe_value)
258+
)
259+
ok = bool(jl.applicable(custom_loss, *probe_args))
260+
if ok:
261+
return
262+
250263
if has_weights:
251-
ok = bool(jl.applicable(custom_loss, 1.0, 1.0, 1.0))
252-
if not ok:
253-
raise ValueError(
254-
"`elementwise_loss` must accept (prediction, target, weight) when `weights` is passed to `fit`."
255-
)
264+
warnings.warn(
265+
"`elementwise_loss` did not match the probed (prediction, target, weight) signature "
266+
"for the dtype used during fitting. Continuing anyway so Julia can surface a more "
267+
"specific `MethodError` if needed.",
268+
stacklevel=2,
269+
)
256270
else:
257-
ok = bool(jl.applicable(custom_loss, 1.0, 1.0))
258-
if not ok:
259-
raise ValueError(
260-
"`elementwise_loss` must accept (prediction, target). If you intended a full objective, use "
261-
"`loss_function` or `loss_function_expression`."
262-
)
271+
warnings.warn(
272+
"`elementwise_loss` did not match the probed (prediction, target) signature for "
273+
"the dtype used during fitting. If you intended a full objective, use "
274+
"`loss_function` or `loss_function_expression`. Continuing anyway so Julia can "
275+
"surface a more specific `MethodError` if needed.",
276+
stacklevel=2,
277+
)
263278

264279

265280
def _validate_custom_objective(
@@ -2109,13 +2124,20 @@ def _run(
21092124
if isinstance(complexity_of_variables, list):
21102125
complexity_of_variables = jl_array(complexity_of_variables)
21112126

2127+
np_dtype = self._get_precision_mapped_dtype(np.array(X))
2128+
probe_value = np_dtype(1.0)
2129+
21122130
custom_loss = jl.seval(
21132131
str(self.elementwise_loss)
21142132
if self.elementwise_loss is not None
21152133
else "nothing"
21162134
)
21172135
if self.elementwise_loss is not None:
2118-
_validate_elementwise_loss(custom_loss, has_weights=weights is not None)
2136+
_validate_elementwise_loss(
2137+
custom_loss,
2138+
has_weights=weights is not None,
2139+
probe_value=probe_value,
2140+
)
21192141

21202142
custom_full_objective = jl.seval(
21212143
str(self.loss_function) if self.loss_function is not None else "nothing"
@@ -2304,8 +2326,6 @@ def _run(
23042326
self.julia_options_stream_ = jl_serialize(options)
23052327

23062328
# Convert data to desired precision
2307-
test_X = np.array(X)
2308-
np_dtype = self._get_precision_mapped_dtype(test_X)
23092329

23102330
# This converts the data into a Julia array:
23112331
jl_X = jl_array(np.array(X, dtype=np_dtype).T)

pysr/test/test_main.py

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -258,25 +258,28 @@ def test_loss_function_varargs_objective_runs(self):
258258
y = np.array([0.0, 1.0])
259259
model.fit(X, y)
260260

261-
def test_elementwise_loss_wrong_signature_errors_early(self):
261+
def test_elementwise_loss_wrong_signature_warns(self):
262262
"""Validate `elementwise_loss` signature (prediction, target[, weights])."""
263-
model = PySRRegressor(
264-
niterations=1,
265-
populations=1,
266-
procs=0,
267-
progress=False,
268-
verbosity=0,
269-
temp_equation_file=True,
270-
binary_operators=["+"],
271-
elementwise_loss="myloss_bad_arity(a) = a",
263+
custom_loss = jl.seval("myloss_bad_arity(a) = a")
264+
with self.assertWarnsRegex(UserWarning, "elementwise_loss"):
265+
_validate_elementwise_loss(
266+
custom_loss,
267+
has_weights=False,
268+
probe_value=np.float32(1.0),
269+
)
270+
271+
def test_elementwise_loss_with_weights_requires_three_args_warns(self):
272+
custom_loss = jl.seval(
273+
"myloss2(prediction, target) = (prediction - target)^2"
272274
)
273-
X = np.array([[0.0], [1.0]])
274-
y = np.array([0.0, 1.0])
275-
with self.assertRaises(ValueError) as cm:
276-
model.fit(X, y)
277-
self.assertIn("elementwise_loss", str(cm.exception))
275+
with self.assertWarnsRegex(UserWarning, "elementwise_loss"):
276+
_validate_elementwise_loss(
277+
custom_loss,
278+
has_weights=True,
279+
probe_value=np.float32(1.0),
280+
)
278281

279-
def test_elementwise_loss_with_weights_requires_three_args(self):
282+
def test_elementwise_loss_with_weights_accepts_three_args(self):
280283
model = PySRRegressor(
281284
niterations=1,
282285
populations=1,
@@ -285,33 +288,51 @@ def test_elementwise_loss_with_weights_requires_three_args(self):
285288
verbosity=0,
286289
temp_equation_file=True,
287290
binary_operators=["+"],
288-
elementwise_loss="myloss2(prediction, target) = (prediction - target)^2",
291+
elementwise_loss=(
292+
"myloss3(prediction, target, weights) = weights * (prediction - target)^2"
293+
),
289294
)
290295
X = np.array([[0.0], [1.0]])
291296
y = np.array([0.0, 1.0])
292297
weights = np.array([1.0, 1.0])
293-
with self.assertRaises(ValueError) as cm:
294-
model.fit(X, y, weights=weights)
295-
self.assertIn("elementwise_loss", str(cm.exception))
296-
self.assertIn("weights", str(cm.exception))
298+
model.fit(X, y, weights=weights)
297299

298-
def test_elementwise_loss_with_weights_accepts_three_args(self):
300+
def test_elementwise_loss_float32_probe_accepts_strictly_typed_loss(self):
301+
custom_loss = jl.seval(
302+
"(prediction::Float32, target::Float32) -> (prediction - target)^2"
303+
)
304+
with warnings.catch_warnings(record=True) as caught:
305+
warnings.simplefilter("always")
306+
_validate_elementwise_loss(
307+
custom_loss,
308+
has_weights=False,
309+
probe_value=np.float32(1.0),
310+
)
311+
self.assertEqual(len(caught), 0)
312+
313+
def test_elementwise_loss_float32_fit_accepts_strictly_typed_loss(self):
299314
model = PySRRegressor(
300315
niterations=1,
301316
populations=1,
302317
procs=0,
303318
progress=False,
304319
verbosity=0,
320+
precision=32,
305321
temp_equation_file=True,
306322
binary_operators=["+"],
307323
elementwise_loss=(
308-
"myloss3(prediction, target, weights) = weights * (prediction - target)^2"
324+
"(prediction::Float32, target::Float32) -> (prediction - target)^2"
309325
),
310326
)
311-
X = np.array([[0.0], [1.0]])
312-
y = np.array([0.0, 1.0])
313-
weights = np.array([1.0, 1.0])
314-
model.fit(X, y, weights=weights)
327+
X = np.array([[0.0], [1.0]], dtype=np.float32)
328+
y = np.array([0.0, 1.0], dtype=np.float32)
329+
with warnings.catch_warnings(record=True) as caught:
330+
warnings.simplefilter("always")
331+
model.fit(X, y)
332+
self.assertFalse(
333+
any("elementwise_loss" in str(w.message) for w in caught),
334+
msg=[str(w.message) for w in caught],
335+
)
315336

316337
def test_validation_helpers_skip_nonfunction(self):
317338
_validate_elementwise_loss(jl.seval("1.0"), has_weights=False)

0 commit comments

Comments
 (0)