@@ -242,6 +242,16 @@ def test_non_pulse_2d_input_full_sequence(self):
242242 loss = t .train_batch (x , y , thinking_steps = 10 , full_sequence = True )
243243 assert isinstance (loss , float )
244244
245+ def test_gradient_accumulation_steps_must_be_positive (self ):
246+ model = _model ()
247+ t = _trainer (model )
248+ x = _batch ()
249+ y = _targets ()
250+ with pytest .raises (ValueError ):
251+ t .train_batch (x , y , thinking_steps = 2 , gradient_accumulation_steps = 0 )
252+ with pytest .raises (ValueError ):
253+ t .train_batch (x , y , thinking_steps = 2 , gradient_accumulation_steps = - 1 )
254+
245255
# ===========================================================================
# predict
@@ -341,6 +351,30 @@ def test_fit_loss_trend_downward_on_simple_data(self):
341351 history = t .fit (x , y , epochs = 20 , batch_size = n , thinking_steps = 5 , verbose = False )
342352 assert history [- 1 ] < history [0 ], "Loss should decrease over training"
343353
354+ def test_fit_empty_dataset_raises (self ):
355+ model = _model ()
356+ t = _trainer (model )
357+ x = torch .empty (0 , 5 )
358+ y = torch .empty (0 , 2 )
359+ with pytest .raises (ValueError ):
360+ t .fit (x , y , epochs = 1 , batch_size = 4 , thinking_steps = 2 , verbose = False )
361+
362+ def test_fit_length_mismatch_raises (self ):
363+ model = _model ()
364+ t = _trainer (model )
365+ x = torch .randn (3 , 5 )
366+ y = torch .randn (2 , 2 )
367+ with pytest .raises (ValueError ):
368+ t .fit (x , y , epochs = 1 , batch_size = 2 , thinking_steps = 2 , verbose = False )
369+
370+ def test_fit_invalid_batch_size_raises (self ):
371+ model = _model ()
372+ t = _trainer (model )
373+ x = torch .randn (4 , 5 )
374+ y = torch .randn (4 , 2 )
375+ with pytest .raises (ValueError ):
376+ t .fit (x , y , epochs = 1 , batch_size = 0 , thinking_steps = 2 , verbose = False )
377+
344378
# ===========================================================================
# regenerate_synapses