@@ -83,6 +83,10 @@ def read_weights(model):
8383 return buf
8484
8585
86+ def read_bias (model ):
87+ return model .get_bias ()
88+
89+
8690def assert_raises_value_error (func , message = 'Expected ValueError' ):
8791 try :
8892 func ()
@@ -105,7 +109,7 @@ def test_logreg_train_and_predict():
105109 assert zero_pred < 0.2 , zero_pred
106110
107111
108- def test_logreg_predict_class_and_weight_io ():
112+ def test_logreg_weight_io_and_probabilities ():
109113 model = emlearn_logreg .new (2 , 0.1 , 0.0 , 0.0 )
110114
111115 manual_weights = array .array ('f' , [2.5 , - 1.5 ])
@@ -120,8 +124,8 @@ def test_logreg_predict_class_and_weight_io():
120124 bias = model .get_bias ()
121125 assert abs (bias + 0.5 ) < 1e-6 , bias
122126
123- assert model .predict_class (array .array ('f' , [2.0 , 0.0 ])) == 1
124- assert model .predict_class (array .array ('f' , [0.0 , 1.0 ])) == 0
127+ assert model .predict (array .array ('f' , [2.0 , 0.0 ])) > 0.5
128+ assert model .predict (array .array ('f' , [0.0 , 1.0 ])) < 0.5
125129
126130
127131def test_logreg_train_minibatch_reduces_loss ():
@@ -228,9 +232,44 @@ def test_logreg_train_requires_targets():
228232 assert_raises_value_error (lambda : emlearn_logreg .train (model , X , y ))
229233
230234
235+ def test_logreg_warm_start_sets_new_weights_and_bias ():
236+ X , y = make_dataset ()
237+ model = emlearn_logreg .new (2 , 0.3 , 0.0 , 0.0 )
238+
239+ emlearn_logreg .train (model , X , y , max_iterations = 20 , check_interval = 5 )
240+ trained_weights = read_weights (model )
241+ trained_bias = read_bias (model )
242+
243+ manual_model = emlearn_logreg .new (2 , 0.3 , 0.0 , 0.0 )
244+ manual_model .set_weights (trained_weights )
245+ manual_model .set_bias (trained_bias )
246+
247+ sample = array .array ('f' , [1.0 , 1.0 ])
248+ pred_trained = model .predict (sample )
249+ pred_manual = manual_model .predict (sample )
250+ assert abs (pred_trained - pred_manual ) < 1e-6
251+
252+
253+ def test_logreg_threshold_adjustment_behaviour ():
254+ model = emlearn_logreg .new (2 , 0.1 , 0.0 , 0.0 )
255+ weights = array .array ('f' , [5.0 , - 5.0 ])
256+ model .set_weights (weights )
257+ model .set_bias (- 1.0 )
258+
259+ features = array .array ('f' , [0.2 , 0.1 ])
260+ proba = model .predict (features )
261+ assert 0.0 < proba < 1.0
262+
263+ default_label = 1 if proba >= 0.5 else 0
264+ custom_threshold = 0.3
265+ custom_label = 1 if proba >= custom_threshold else 0
266+
267+ assert custom_label >= default_label
268+
269+
231270if __name__ == '__main__' :
232271 test_logreg_train_and_predict ()
233- test_logreg_predict_class_and_weight_io ()
272+ test_logreg_weight_io_and_probabilities ()
234273 test_logreg_train_minibatch_reduces_loss ()
235274 test_logreg_l2_penalty_shrinks_weights ()
236275 test_logreg_l1_penalty_promotes_sparsity ()
0 commit comments