Skip to content

Commit dd6436c

Browse files
committed
logreg: Remove predict_class
Consumers should use the continious values
1 parent 63d2f0b commit dd6436c

3 files changed

Lines changed: 54 additions & 40 deletions

File tree

src/emlearn_logreg/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Location of top-level MicroPython directory
2-
MPY_DIR = ../../micropython
2+
MPY_DIR ?= ../../micropython
33

44
# Architecture to build for (x86, x64, armv6m, armv7m, xtensa, xtensawin)
55
ARCH = x64

src/emlearn_logreg/logreg.c

Lines changed: 10 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -137,30 +137,6 @@ static mp_obj_t logreg_model_predict(mp_obj_t self_obj, mp_obj_t features_obj) {
137137
}
138138
static MP_DEFINE_CONST_FUN_OBJ_2(logreg_model_predict_obj, logreg_model_predict);
139139

140-
static mp_obj_t logreg_model_predict_class(mp_obj_t self_obj, mp_obj_t features_obj) {
141-
mp_obj_logreg_model_t *o = MP_OBJ_TO_PTR(self_obj);
142-
logreg_model_t *self = &o->model;
143-
144-
mp_buffer_info_t bufinfo;
145-
mp_get_buffer_raise(features_obj, &bufinfo, MP_BUFFER_READ);
146-
if (bufinfo.typecode != 'f') {
147-
mp_raise_ValueError(MP_ERROR_TEXT("expecting float32 array"));
148-
}
149-
const float *features = bufinfo.buf;
150-
const int n_features = bufinfo.len / sizeof(float);
151-
152-
if (n_features != self->n_features) {
153-
mp_raise_ValueError(MP_ERROR_TEXT("Feature count mismatch"));
154-
}
155-
156-
const float threshold = 0.5f;
157-
158-
uint8_t label = logreg_predict_proba(self, features) >= threshold ? 1 : 0;
159-
160-
return mp_obj_new_int(label);
161-
}
162-
static MP_DEFINE_CONST_FUN_OBJ_2(logreg_model_predict_class_obj, logreg_model_predict_class);
163-
164140
// Get model weights
165141
static mp_obj_t logreg_model_get_weights(mp_obj_t self_obj, mp_obj_t out_obj) {
166142
mp_obj_logreg_model_t *o = MP_OBJ_TO_PTR(self_obj);
@@ -279,17 +255,16 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
279255
logreg_model_type.name = MP_QSTR_logreg;
280256

281257
logreg_model_locals_dict_table[0] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict), MP_OBJ_FROM_PTR(&logreg_model_predict_obj) };
282-
logreg_model_locals_dict_table[1] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict_class), MP_OBJ_FROM_PTR(&logreg_model_predict_class_obj) };
283-
logreg_model_locals_dict_table[2] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_step), MP_OBJ_FROM_PTR(&logreg_model_step_obj) };
284-
logreg_model_locals_dict_table[3] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&logreg_model_del_obj) };
285-
logreg_model_locals_dict_table[4] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_weights), MP_OBJ_FROM_PTR(&logreg_model_get_weights_obj) };
286-
logreg_model_locals_dict_table[5] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_set_weights), MP_OBJ_FROM_PTR(&logreg_model_set_weights_obj) };
287-
logreg_model_locals_dict_table[6] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_bias), MP_OBJ_FROM_PTR(&logreg_model_get_bias_obj) };
288-
logreg_model_locals_dict_table[7] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_set_bias), MP_OBJ_FROM_PTR(&logreg_model_set_bias_obj) };
289-
logreg_model_locals_dict_table[8] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_features), MP_OBJ_FROM_PTR(&logreg_model_get_n_features_obj) };
290-
logreg_model_locals_dict_table[9] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_score_logloss), MP_OBJ_FROM_PTR(&logreg_model_score_logloss_obj) };
291-
292-
MP_OBJ_TYPE_SET_SLOT(&logreg_model_type, locals_dict, (void *)&logreg_model_locals_dict, 10);
258+
logreg_model_locals_dict_table[1] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_step), MP_OBJ_FROM_PTR(&logreg_model_step_obj) };
259+
logreg_model_locals_dict_table[2] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&logreg_model_del_obj) };
260+
logreg_model_locals_dict_table[3] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_weights), MP_OBJ_FROM_PTR(&logreg_model_get_weights_obj) };
261+
logreg_model_locals_dict_table[4] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_set_weights), MP_OBJ_FROM_PTR(&logreg_model_set_weights_obj) };
262+
logreg_model_locals_dict_table[5] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_bias), MP_OBJ_FROM_PTR(&logreg_model_get_bias_obj) };
263+
logreg_model_locals_dict_table[6] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_set_bias), MP_OBJ_FROM_PTR(&logreg_model_set_bias_obj) };
264+
logreg_model_locals_dict_table[7] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_features), MP_OBJ_FROM_PTR(&logreg_model_get_n_features_obj) };
265+
logreg_model_locals_dict_table[8] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_score_logloss), MP_OBJ_FROM_PTR(&logreg_model_score_logloss_obj) };
266+
267+
MP_OBJ_TYPE_SET_SLOT(&logreg_model_type, locals_dict, (void *)&logreg_model_locals_dict, 9);
293268

294269
MP_DYNRUNTIME_INIT_EXIT
295270
}

tests/test_logreg.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ def read_weights(model):
8383
return buf
8484

8585

86+
def read_bias(model):
87+
return model.get_bias()
88+
89+
8690
def assert_raises_value_error(func, message='Expected ValueError'):
8791
try:
8892
func()
@@ -105,7 +109,7 @@ def test_logreg_train_and_predict():
105109
assert zero_pred < 0.2, zero_pred
106110

107111

108-
def test_logreg_predict_class_and_weight_io():
112+
def test_logreg_weight_io_and_probabilities():
109113
model = emlearn_logreg.new(2, 0.1, 0.0, 0.0)
110114

111115
manual_weights = array.array('f', [2.5, -1.5])
@@ -120,8 +124,8 @@ def test_logreg_predict_class_and_weight_io():
120124
bias = model.get_bias()
121125
assert abs(bias + 0.5) < 1e-6, bias
122126

123-
assert model.predict_class(array.array('f', [2.0, 0.0])) == 1
124-
assert model.predict_class(array.array('f', [0.0, 1.0])) == 0
127+
assert model.predict(array.array('f', [2.0, 0.0])) > 0.5
128+
assert model.predict(array.array('f', [0.0, 1.0])) < 0.5
125129

126130

127131
def test_logreg_train_minibatch_reduces_loss():
@@ -228,9 +232,44 @@ def test_logreg_train_requires_targets():
228232
assert_raises_value_error(lambda: emlearn_logreg.train(model, X, y))
229233

230234

235+
def test_logreg_warm_start_sets_new_weights_and_bias():
236+
X, y = make_dataset()
237+
model = emlearn_logreg.new(2, 0.3, 0.0, 0.0)
238+
239+
emlearn_logreg.train(model, X, y, max_iterations=20, check_interval=5)
240+
trained_weights = read_weights(model)
241+
trained_bias = read_bias(model)
242+
243+
manual_model = emlearn_logreg.new(2, 0.3, 0.0, 0.0)
244+
manual_model.set_weights(trained_weights)
245+
manual_model.set_bias(trained_bias)
246+
247+
sample = array.array('f', [1.0, 1.0])
248+
pred_trained = model.predict(sample)
249+
pred_manual = manual_model.predict(sample)
250+
assert abs(pred_trained - pred_manual) < 1e-6
251+
252+
253+
def test_logreg_threshold_adjustment_behaviour():
254+
model = emlearn_logreg.new(2, 0.1, 0.0, 0.0)
255+
weights = array.array('f', [5.0, -5.0])
256+
model.set_weights(weights)
257+
model.set_bias(-1.0)
258+
259+
features = array.array('f', [0.2, 0.1])
260+
proba = model.predict(features)
261+
assert 0.0 < proba < 1.0
262+
263+
default_label = 1 if proba >= 0.5 else 0
264+
custom_threshold = 0.3
265+
custom_label = 1 if proba >= custom_threshold else 0
266+
267+
assert custom_label >= default_label
268+
269+
231270
if __name__ == '__main__':
232271
test_logreg_train_and_predict()
233-
test_logreg_predict_class_and_weight_io()
272+
test_logreg_weight_io_and_probabilities()
234273
test_logreg_train_minibatch_reduces_loss()
235274
test_logreg_l2_penalty_shrinks_weights()
236275
test_logreg_l1_penalty_promotes_sparsity()

0 commit comments

Comments
 (0)