Skip to content

Commit e8368af

Browse files
refactor: rename loss probe parameter
Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent 2a866c2 commit e8368af

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

pysr/sr.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,21 +236,21 @@ def _check_assertions(
236236

237237

238238
def _validate_elementwise_loss(
239-
custom_loss, *, has_weights: bool, probe_dtype: Callable[[float], Any] = float
239+
custom_loss, *, has_weights: bool, probe_value: Any = 1.0
240240
) -> None:
241241
"""Check whether a Julia `elementwise_loss` accepts the expected inputs.
242242
243243
The function probes the loss with two or three arguments, depending on
244-
whether weights are present, using the dtype that fitting will use. If the
245-
probe fails, it raises a `ValueError` describing the expected signature.
244+
whether weights are present, using the same dtype that fitting will use.
245+
If the probe fails, it raises a `ValueError` describing the expected
246+
signature.
246247
"""
247248

248249
# This can be either a LossFunctions.jl object (e.g. `L2DistLoss()`) or a Julia function.
249250
# Only validate arity when the evaluated object is actually a function.
250251
if not jl_is_function(custom_loss):
251252
return
252253

253-
probe_value = probe_dtype(1.0)
254254
probe_args = (
255255
(probe_value, probe_value, probe_value)
256256
if has_weights
@@ -2128,7 +2128,7 @@ def _run(
21282128
_validate_elementwise_loss(
21292129
custom_loss,
21302130
has_weights=weights is not None,
2131-
probe_dtype=np_dtype,
2131+
probe_value=np_dtype(1.0),
21322132
)
21332133

21342134
custom_full_objective = jl.seval(

0 commit comments

Comments
 (0)