@@ -235,26 +235,34 @@ def _check_assertions(
235235 )
236236
237237
238- def _validate_elementwise_loss (custom_loss , * , has_weights : bool ) -> None :
239- """Validate that a Julia `elementwise_loss` is callable.
238+ def _validate_elementwise_loss (
239+ custom_loss , * , has_weights : bool , probe_value : Any = 1.0
240+ ) -> None :
241+ """Check whether a Julia `elementwise_loss` accepts the expected inputs.
240242
241- We require exactly 2 args unless the user passed `weights=` to fit,
242- in which case we require 3 args.
243+ The function probes the loss with two or three arguments, depending on
244+ whether weights are present, using the same dtype that fitting will use.
245+ If the probe fails, it raises a `ValueError` describing the expected
246+ signature.
243247 """
244248
245249 # This can be either a LossFunctions.jl object (e.g. `L2DistLoss()`) or a Julia function.
246250 # Only validate arity when the evaluated object is actually a function.
247251 if not jl_is_function (custom_loss ):
248252 return
249253
254+ probe_args = (
255+ (probe_value , probe_value , probe_value )
256+ if has_weights
257+ else (probe_value , probe_value )
258+ )
259+ ok = bool (jl .applicable (custom_loss , * probe_args ))
250260 if has_weights :
251- ok = bool (jl .applicable (custom_loss , 1.0 , 1.0 , 1.0 ))
252261 if not ok :
253262 raise ValueError (
254263 "`elementwise_loss` must accept (prediction, target, weight) when `weights` is passed to `fit`."
255264 )
256265 else :
257- ok = bool (jl .applicable (custom_loss , 1.0 , 1.0 ))
258266 if not ok :
259267 raise ValueError (
260268 "`elementwise_loss` must accept (prediction, target). If you intended a full objective, use "
@@ -2109,13 +2117,19 @@ def _run(
21092117 if isinstance (complexity_of_variables , list ):
21102118 complexity_of_variables = jl_array (complexity_of_variables )
21112119
2120+ np_dtype = self ._get_precision_mapped_dtype (np .array (X ))
2121+
21122122 custom_loss = jl .seval (
21132123 str (self .elementwise_loss )
21142124 if self .elementwise_loss is not None
21152125 else "nothing"
21162126 )
21172127 if self .elementwise_loss is not None :
2118- _validate_elementwise_loss (custom_loss , has_weights = weights is not None )
2128+ _validate_elementwise_loss (
2129+ custom_loss ,
2130+ has_weights = weights is not None ,
2131+ probe_value = np_dtype (1.0 ),
2132+ )
21192133
21202134 custom_full_objective = jl .seval (
21212135 str (self .loss_function ) if self .loss_function is not None else "nothing"
@@ -2304,8 +2318,6 @@ def _run(
23042318 self .julia_options_stream_ = jl_serialize (options )
23052319
23062320 # Convert data to desired precision
2307- test_X = np .array (X )
2308- np_dtype = self ._get_precision_mapped_dtype (test_X )
23092321
23102322 # This converts the data into a Julia array:
23112323 jl_X = jl_array (np .array (X , dtype = np_dtype ).T )
0 commit comments