Skip to content

Commit 2a866c2

Browse files
docs: clarify elementwise loss validator docstring
Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent 8d71c2f commit 2a866c2

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

pysr/sr.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,11 @@ def _check_assertions(
238238
def _validate_elementwise_loss(
239239
custom_loss, *, has_weights: bool, probe_dtype: Callable[[float], Any] = float
240240
) -> None:
241-
"""Validate that a Julia `elementwise_loss` is callable.
241+
"""Check whether a Julia `elementwise_loss` accepts the expected inputs.
242242
243-
We probe with the dtype that the Julia backend will use, which avoids
244-
falsely rejecting strictly typed losses such as `(::Float32, ::Float32)`.
243+
The function probes the loss with two or three arguments, depending on
244+
whether weights are present, using the dtype that fitting will use. If the
245+
probe fails, it raises a `ValueError` describing the expected signature.
245246
"""
246247

247248
# This can be either a LossFunctions.jl object (e.g. `L2DistLoss()`) or a Julia function.

0 commit comments

Comments
 (0)