@@ -258,24 +258,42 @@ def test_loss_function_varargs_objective_runs(self):
258258 y = np .array ([0.0 , 1.0 ])
259259 model .fit (X , y )
260260
261- def test_elementwise_loss_wrong_signature_warns (self ):
261+ def test_elementwise_loss_wrong_signature_errors_early (self ):
262262 """Validate `elementwise_loss` signature (prediction, target[, weights])."""
263- custom_loss = jl .seval ("myloss_bad_arity(a) = a" )
264- with self .assertWarnsRegex (UserWarning , "elementwise_loss" ):
265- _validate_elementwise_loss (
266- custom_loss ,
267- has_weights = False ,
268- probe_value = np .float32 (1.0 ),
269- )
263+ model = PySRRegressor (
264+ niterations = 1 ,
265+ populations = 1 ,
266+ procs = 0 ,
267+ progress = False ,
268+ verbosity = 0 ,
269+ temp_equation_file = True ,
270+ binary_operators = ["+" ],
271+ elementwise_loss = "myloss_bad_arity(a) = a" ,
272+ )
273+ X = np .array ([[0.0 ], [1.0 ]])
274+ y = np .array ([0.0 , 1.0 ])
275+ with self .assertRaises (ValueError ) as cm :
276+ model .fit (X , y )
277+ self .assertIn ("elementwise_loss" , str (cm .exception ))
270278
271- def test_elementwise_loss_with_weights_requires_three_args_warns (self ):
272- custom_loss = jl .seval ("myloss2(prediction, target) = (prediction - target)^2" )
273- with self .assertWarnsRegex (UserWarning , "elementwise_loss" ):
274- _validate_elementwise_loss (
275- custom_loss ,
276- has_weights = True ,
277- probe_value = np .float32 (1.0 ),
278- )
279+ def test_elementwise_loss_with_weights_requires_three_args (self ):
280+ model = PySRRegressor (
281+ niterations = 1 ,
282+ populations = 1 ,
283+ procs = 0 ,
284+ progress = False ,
285+ verbosity = 0 ,
286+ temp_equation_file = True ,
287+ binary_operators = ["+" ],
288+ elementwise_loss = "myloss2(prediction, target) = (prediction - target)^2" ,
289+ )
290+ X = np .array ([[0.0 ], [1.0 ]])
291+ y = np .array ([0.0 , 1.0 ])
292+ weights = np .array ([1.0 , 1.0 ])
293+ with self .assertRaises (ValueError ) as cm :
294+ model .fit (X , y , weights = weights )
295+ self .assertIn ("elementwise_loss" , str (cm .exception ))
296+ self .assertIn ("weights" , str (cm .exception ))
279297
280298 def test_elementwise_loss_with_weights_accepts_three_args (self ):
281299 model = PySRRegressor (
@@ -299,14 +317,11 @@ def test_elementwise_loss_float32_probe_accepts_strictly_typed_loss(self):
299317 custom_loss = jl .seval (
300318 "(prediction::Float32, target::Float32) -> (prediction - target)^2"
301319 )
302- with warnings .catch_warnings (record = True ) as caught :
303- warnings .simplefilter ("always" )
304- _validate_elementwise_loss (
305- custom_loss ,
306- has_weights = False ,
307- probe_value = np .float32 (1.0 ),
308- )
309- self .assertEqual (len (caught ), 0 )
320+ _validate_elementwise_loss (
321+ custom_loss ,
322+ has_weights = False ,
323+ probe_dtype = np .float32 ,
324+ )
310325
311326 def test_elementwise_loss_float32_fit_accepts_strictly_typed_loss (self ):
312327 model = PySRRegressor (
@@ -324,13 +339,7 @@ def test_elementwise_loss_float32_fit_accepts_strictly_typed_loss(self):
324339 )
325340 X = np .array ([[0.0 ], [1.0 ]], dtype = np .float32 )
326341 y = np .array ([0.0 , 1.0 ], dtype = np .float32 )
327- with warnings .catch_warnings (record = True ) as caught :
328- warnings .simplefilter ("always" )
329- model .fit (X , y )
330- self .assertFalse (
331- any ("elementwise_loss" in str (w .message ) for w in caught ),
332- msg = [str (w .message ) for w in caught ],
333- )
342+ model .fit (X , y )
334343
335344 def test_validation_helpers_skip_nonfunction (self ):
336345 _validate_elementwise_loss (jl .seval ("1.0" ), has_weights = False )
0 commit comments