@@ -236,21 +236,21 @@ def _check_assertions(
236236
237237
238238def _validate_elementwise_loss (
239- custom_loss , * , has_weights : bool , probe_dtype : Callable [[ float ], Any ] = float
239+ custom_loss , * , has_weights : bool , probe_value : Any = 1.0
240240) -> None :
241241 """Check whether a Julia `elementwise_loss` accepts the expected inputs.
242242
243243 The function probes the loss with two or three arguments, depending on
244- whether weights are present, using the dtype that fitting will use. If the
245- probe fails, it raises a `ValueError` describing the expected signature.
244+ whether weights are present, using the same dtype that fitting will use.
245+ If the probe fails, it raises a `ValueError` describing the expected
246+ signature.
246247 """
247248
248249 # This can be either a LossFunctions.jl object (e.g. `L2DistLoss()`) or a Julia function.
249250 # Only validate arity when the evaluated object is actually a function.
250251 if not jl_is_function (custom_loss ):
251252 return
252253
253- probe_value = probe_dtype (1.0 )
254254 probe_args = (
255255 (probe_value , probe_value , probe_value )
256256 if has_weights
@@ -2128,7 +2128,7 @@ def _run(
21282128 _validate_elementwise_loss (
21292129 custom_loss ,
21302130 has_weights = weights is not None ,
2131- probe_dtype = np_dtype ,
2131+ probe_value = np_dtype ( 1.0 ) ,
21322132 )
21332133
21342134 custom_full_objective = jl .seval (
0 commit comments