@@ -258,25 +258,28 @@ def test_loss_function_varargs_objective_runs(self):
258258 y = np .array ([0.0 , 1.0 ])
259259 model .fit (X , y )
260260
261- def test_elementwise_loss_wrong_signature_errors_early (self ):
261+ def test_elementwise_loss_wrong_signature_warns (self ):
262262 """Validate `elementwise_loss` signature (prediction, target[, weights])."""
263- model = PySRRegressor (
264- niterations = 1 ,
265- populations = 1 ,
266- procs = 0 ,
267- progress = False ,
268- verbosity = 0 ,
269- temp_equation_file = True ,
270- binary_operators = ["+" ],
271- elementwise_loss = "myloss_bad_arity(a) = a" ,
263+ custom_loss = jl .seval ("myloss_bad_arity(a) = a" )
264+ with self .assertWarnsRegex (UserWarning , "elementwise_loss" ):
265+ _validate_elementwise_loss (
266+ custom_loss ,
267+ has_weights = False ,
268+ probe_value = np .float32 (1.0 ),
269+ )
270+
271+ def test_elementwise_loss_with_weights_requires_three_args_warns (self ):
272+ custom_loss = jl .seval (
273+ "myloss2(prediction, target) = (prediction - target)^2"
272274 )
273- X = np .array ([[0.0 ], [1.0 ]])
274- y = np .array ([0.0 , 1.0 ])
275- with self .assertRaises (ValueError ) as cm :
276- model .fit (X , y )
277- self .assertIn ("elementwise_loss" , str (cm .exception ))
275+ with self .assertWarnsRegex (UserWarning , "elementwise_loss" ):
276+ _validate_elementwise_loss (
277+ custom_loss ,
278+ has_weights = True ,
279+ probe_value = np .float32 (1.0 ),
280+ )
278281
279- def test_elementwise_loss_with_weights_requires_three_args (self ):
282+ def test_elementwise_loss_with_weights_accepts_three_args (self ):
280283 model = PySRRegressor (
281284 niterations = 1 ,
282285 populations = 1 ,
@@ -285,33 +288,51 @@ def test_elementwise_loss_with_weights_requires_three_args(self):
285288 verbosity = 0 ,
286289 temp_equation_file = True ,
287290 binary_operators = ["+" ],
288- elementwise_loss = "myloss2(prediction, target) = (prediction - target)^2" ,
291+ elementwise_loss = (
292+ "myloss3(prediction, target, weights) = weights * (prediction - target)^2"
293+ ),
289294 )
290295 X = np .array ([[0.0 ], [1.0 ]])
291296 y = np .array ([0.0 , 1.0 ])
292297 weights = np .array ([1.0 , 1.0 ])
293- with self .assertRaises (ValueError ) as cm :
294- model .fit (X , y , weights = weights )
295- self .assertIn ("elementwise_loss" , str (cm .exception ))
296- self .assertIn ("weights" , str (cm .exception ))
298+ model .fit (X , y , weights = weights )
297299
298- def test_elementwise_loss_with_weights_accepts_three_args (self ):
300+ def test_elementwise_loss_float32_probe_accepts_strictly_typed_loss (self ):
301+ custom_loss = jl .seval (
302+ "(prediction::Float32, target::Float32) -> (prediction - target)^2"
303+ )
304+ with warnings .catch_warnings (record = True ) as caught :
305+ warnings .simplefilter ("always" )
306+ _validate_elementwise_loss (
307+ custom_loss ,
308+ has_weights = False ,
309+ probe_value = np .float32 (1.0 ),
310+ )
311+ self .assertEqual (len (caught ), 0 )
312+
313+ def test_elementwise_loss_float32_fit_accepts_strictly_typed_loss (self ):
299314 model = PySRRegressor (
300315 niterations = 1 ,
301316 populations = 1 ,
302317 procs = 0 ,
303318 progress = False ,
304319 verbosity = 0 ,
320+ precision = 32 ,
305321 temp_equation_file = True ,
306322 binary_operators = ["+" ],
307323 elementwise_loss = (
308- "myloss3 (prediction, target, weights) = weights * (prediction - target)^2"
324+ "(prediction::Float32 , target::Float32) -> (prediction - target)^2"
309325 ),
310326 )
311- X = np .array ([[0.0 ], [1.0 ]])
312- y = np .array ([0.0 , 1.0 ])
313- weights = np .array ([1.0 , 1.0 ])
314- model .fit (X , y , weights = weights )
327+ X = np .array ([[0.0 ], [1.0 ]], dtype = np .float32 )
328+ y = np .array ([0.0 , 1.0 ], dtype = np .float32 )
329+ with warnings .catch_warnings (record = True ) as caught :
330+ warnings .simplefilter ("always" )
331+ model .fit (X , y )
332+ self .assertFalse (
333+ any ("elementwise_loss" in str (w .message ) for w in caught ),
334+ msg = [str (w .message ) for w in caught ],
335+ )
315336
316337 def test_validation_helpers_skip_nonfunction (self ):
317338 _validate_elementwise_loss (jl .seval ("1.0" ), has_weights = False )
0 commit comments