@@ -236,7 +236,7 @@ def _check_assertions(
236236
237237
238238def _validate_elementwise_loss (
239- custom_loss , * , has_weights : bool , probe_dtype : Callable [[ float ], Any ] = float
239+ custom_loss , * , has_weights : bool , probe_value : Any = 1.0
240240) -> None :
241241 """Validate that a Julia `elementwise_loss` is callable.
242242
@@ -251,7 +251,6 @@ def _validate_elementwise_loss(
251251 if not jl_is_function (custom_loss ):
252252 return
253253
254- probe_value = probe_dtype (1.0 )
255254 probe_args = (
256255 (probe_value , probe_value , probe_value )
257256 if has_weights
@@ -2126,6 +2125,7 @@ def _run(
21262125 complexity_of_variables = jl_array (complexity_of_variables )
21272126
21282127 np_dtype = self ._get_precision_mapped_dtype (np .array (X ))
2128+ probe_value = np_dtype (1.0 )
21292129
21302130 custom_loss = jl .seval (
21312131 str (self .elementwise_loss )
@@ -2136,7 +2136,7 @@ def _run(
21362136 _validate_elementwise_loss (
21372137 custom_loss ,
21382138 has_weights = weights is not None ,
2139- probe_dtype = np_dtype ,
2139+ probe_value = probe_value ,
21402140 )
21412141
21422142 custom_full_objective = jl .seval (
0 commit comments