Skip to content

Commit 912c4c2

Browse files
MilesCranmerBotMilesCranmerpre-commit-ci[bot]
authored
fix: correct type in elementwise loss validation (MilesCranmer#1184)
* fix: warn on elementwise loss probe mismatch Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: simplify elementwise loss probe dtype Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com> * Revert "fix: simplify elementwise loss probe dtype" This reverts commit 775ab95. * fix: keep elementwise loss mismatch strict Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com> * docs: clarify elementwise loss validator docstring Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com> * refactor: rename loss probe parameter Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com> * test: update loss probe helper call Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com> --------- Co-authored-by: MilesCranmerBot <milescranmerbot@users.noreply.github.com> Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c435527 commit 912c4c2

2 files changed

Lines changed: 49 additions & 9 deletions

File tree

pysr/sr.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -235,26 +235,34 @@ def _check_assertions(
235235
)
236236

237237

238-
def _validate_elementwise_loss(custom_loss, *, has_weights: bool) -> None:
239-
"""Validate that a Julia `elementwise_loss` is callable.
238+
def _validate_elementwise_loss(
239+
custom_loss, *, has_weights: bool, probe_value: Any = 1.0
240+
) -> None:
241+
"""Check whether a Julia `elementwise_loss` accepts the expected inputs.
240242
241-
We require exactly 2 args unless the user passed `weights=` to fit,
242-
in which case we require 3 args.
243+
The function probes the loss with two or three arguments, depending on
244+
whether weights are present, using the same dtype that fitting will use.
245+
If the probe fails, it raises a `ValueError` describing the expected
246+
signature.
243247
"""
244248

245249
# This can be either a LossFunctions.jl object (e.g. `L2DistLoss()`) or a Julia function.
246250
# Only validate arity when the evaluated object is actually a function.
247251
if not jl_is_function(custom_loss):
248252
return
249253

254+
probe_args = (
255+
(probe_value, probe_value, probe_value)
256+
if has_weights
257+
else (probe_value, probe_value)
258+
)
259+
ok = bool(jl.applicable(custom_loss, *probe_args))
250260
if has_weights:
251-
ok = bool(jl.applicable(custom_loss, 1.0, 1.0, 1.0))
252261
if not ok:
253262
raise ValueError(
254263
"`elementwise_loss` must accept (prediction, target, weight) when `weights` is passed to `fit`."
255264
)
256265
else:
257-
ok = bool(jl.applicable(custom_loss, 1.0, 1.0))
258266
if not ok:
259267
raise ValueError(
260268
"`elementwise_loss` must accept (prediction, target). If you intended a full objective, use "
@@ -2109,13 +2117,19 @@ def _run(
21092117
if isinstance(complexity_of_variables, list):
21102118
complexity_of_variables = jl_array(complexity_of_variables)
21112119

2120+
np_dtype = self._get_precision_mapped_dtype(np.array(X))
2121+
21122122
custom_loss = jl.seval(
21132123
str(self.elementwise_loss)
21142124
if self.elementwise_loss is not None
21152125
else "nothing"
21162126
)
21172127
if self.elementwise_loss is not None:
2118-
_validate_elementwise_loss(custom_loss, has_weights=weights is not None)
2128+
_validate_elementwise_loss(
2129+
custom_loss,
2130+
has_weights=weights is not None,
2131+
probe_value=np_dtype(1.0),
2132+
)
21192133

21202134
custom_full_objective = jl.seval(
21212135
str(self.loss_function) if self.loss_function is not None else "nothing"
@@ -2304,8 +2318,6 @@ def _run(
23042318
self.julia_options_stream_ = jl_serialize(options)
23052319

23062320
# Convert data to desired precision
2307-
test_X = np.array(X)
2308-
np_dtype = self._get_precision_mapped_dtype(test_X)
23092321

23102322
# This converts the data into a Julia array:
23112323
jl_X = jl_array(np.array(X, dtype=np_dtype).T)

pysr/test/test_main.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,34 @@ def test_elementwise_loss_with_weights_accepts_three_args(self):
313313
weights = np.array([1.0, 1.0])
314314
model.fit(X, y, weights=weights)
315315

316+
def test_elementwise_loss_float32_probe_accepts_strictly_typed_loss(self):
317+
custom_loss = jl.seval(
318+
"(prediction::Float32, target::Float32) -> (prediction - target)^2"
319+
)
320+
_validate_elementwise_loss(
321+
custom_loss,
322+
has_weights=False,
323+
probe_value=np.float32(1.0),
324+
)
325+
326+
def test_elementwise_loss_float32_fit_accepts_strictly_typed_loss(self):
327+
model = PySRRegressor(
328+
niterations=1,
329+
populations=1,
330+
procs=0,
331+
progress=False,
332+
verbosity=0,
333+
precision=32,
334+
temp_equation_file=True,
335+
binary_operators=["+"],
336+
elementwise_loss=(
337+
"(prediction::Float32, target::Float32) -> (prediction - target)^2"
338+
),
339+
)
340+
X = np.array([[0.0], [1.0]], dtype=np.float32)
341+
y = np.array([0.0, 1.0], dtype=np.float32)
342+
model.fit(X, y)
343+
316344
def test_validation_helpers_skip_nonfunction(self):
317345
_validate_elementwise_loss(jl.seval("1.0"), has_weights=False)
318346

0 commit comments

Comments
 (0)