Skip to content

Commit 03bd038

Browse files
committed
extratrees: Return probabilities in predict
1 parent f32c66e commit 03bd038

5 files changed

Lines changed: 41 additions & 59 deletions

File tree

src/emlearn_extratrees/extratrees.c

Lines changed: 13 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ static mp_obj_t extratrees_model_train_step(mp_obj_t self_obj) {
269269
static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_train_step_obj, extratrees_model_train_step);
270270

271271
// Predict using the model (fills the caller-supplied probabilities buffer, returns the predicted class)
272-
static mp_obj_t extratrees_model_predict_proba(size_t n_args, const mp_obj_t *args) {
272+
static mp_obj_t extratrees_model_predict(size_t n_args, const mp_obj_t *args) {
273273
if (n_args != 3) {
274274
mp_raise_ValueError(MP_ERROR_TEXT("Expected 3 arguments: self, features, probabilities"));
275275
}
@@ -309,37 +309,9 @@ static mp_obj_t extratrees_model_predict_proba(size_t n_args, const mp_obj_t *ar
309309

310310
return mp_obj_new_int(predicted_class);
311311
}
312-
static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(extratrees_model_predict_proba_obj, 3, 3, extratrees_model_predict_proba);
313-
314-
// Predict using the model (returns only class label)
315-
static mp_obj_t extratrees_model_predict(size_t n_args, const mp_obj_t *args) {
316-
if (n_args != 2) {
317-
mp_raise_ValueError(MP_ERROR_TEXT("Expected 2 arguments: self, features"));
318-
}
312+
static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(extratrees_model_predict_obj, 3, 3, extratrees_model_predict);
319313

320-
mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(args[0]);
321-
EmlExtraTreesModel *model = &o->model;
322-
EmlExtraTreesWorkspace *workspace = &o->workspace;
323-
324-
// Extract features buffer pointer and verify typecode
325-
mp_buffer_info_t features_bufinfo;
326-
mp_get_buffer_raise(args[1], &features_bufinfo, MP_BUFFER_READ);
327-
if (features_bufinfo.typecode != 'h') { // int16_t
328-
mp_raise_ValueError(MP_ERROR_TEXT("features expecting int16 array"));
329-
}
330-
const int16_t *features = features_bufinfo.buf;
331-
const int n_features = features_bufinfo.len / sizeof(int16_t);
332-
333-
if (n_features != model->n_features) {
334-
mp_raise_ValueError(MP_ERROR_TEXT("Feature count mismatch"));
335-
}
336-
337-
// Make prediction using pre-allocated workspace arrays
338-
int16_t predicted_class = eml_extratrees_predict_proba(model, features, workspace->probabilities, workspace->votes);
339-
340-
return mp_obj_new_int(predicted_class);
341-
}
342-
static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(extratrees_model_predict_obj, 2, 2, extratrees_model_predict);
314+
// (old predict_proba removed - predict now fills the probabilities buffer and returns the class label)
343315

344316
// Get number of features
345317
static mp_obj_t extratrees_model_get_n_features(mp_obj_t self_obj) {
@@ -396,7 +368,7 @@ static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_get_n_trees_trained_obj, extra
396368
#if MICROPY_ENABLE_DYNRUNTIME
397369

398370
// Module setup
399-
mp_map_elem_t extratrees_model_locals_dict_table[12];
371+
mp_map_elem_t extratrees_model_locals_dict_table[11];
400372
static MP_DEFINE_CONST_DICT(extratrees_model_locals_dict, extratrees_model_locals_dict_table);
401373

402374
// Module setup entrypoint
@@ -413,16 +385,15 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
413385
// methods
414386
extratrees_model_locals_dict_table[0] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_make_new), MP_OBJ_FROM_PTR(&extratrees_model_make_new_obj) };
415387
extratrees_model_locals_dict_table[1] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict), MP_OBJ_FROM_PTR(&extratrees_model_predict_obj) };
416-
extratrees_model_locals_dict_table[2] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict_proba), MP_OBJ_FROM_PTR(&extratrees_model_predict_proba_obj) };
417-
extratrees_model_locals_dict_table[3] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train), MP_OBJ_FROM_PTR(&extratrees_model_train_obj) };
418-
extratrees_model_locals_dict_table[4] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train_init), MP_OBJ_FROM_PTR(&extratrees_model_train_init_obj) };
419-
extratrees_model_locals_dict_table[5] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train_step), MP_OBJ_FROM_PTR(&extratrees_model_train_step_obj) };
420-
extratrees_model_locals_dict_table[6] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&extratrees_model_del_obj) };
421-
extratrees_model_locals_dict_table[7] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_features), MP_OBJ_FROM_PTR(&extratrees_model_get_n_features_obj) };
422-
extratrees_model_locals_dict_table[8] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_classes), MP_OBJ_FROM_PTR(&extratrees_model_get_n_classes_obj) };
423-
extratrees_model_locals_dict_table[9] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_trees), MP_OBJ_FROM_PTR(&extratrees_model_get_n_trees_obj) };
424-
extratrees_model_locals_dict_table[10] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_nodes_used), MP_OBJ_FROM_PTR(&extratrees_model_get_n_nodes_used_obj) };
425-
extratrees_model_locals_dict_table[11] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_trees_trained), MP_OBJ_FROM_PTR(&extratrees_model_get_n_trees_trained_obj) };
388+
extratrees_model_locals_dict_table[2] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train), MP_OBJ_FROM_PTR(&extratrees_model_train_obj) };
389+
extratrees_model_locals_dict_table[3] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train_init), MP_OBJ_FROM_PTR(&extratrees_model_train_init_obj) };
390+
extratrees_model_locals_dict_table[4] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train_step), MP_OBJ_FROM_PTR(&extratrees_model_train_step_obj) };
391+
extratrees_model_locals_dict_table[5] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&extratrees_model_del_obj) };
392+
extratrees_model_locals_dict_table[6] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_features), MP_OBJ_FROM_PTR(&extratrees_model_get_n_features_obj) };
393+
extratrees_model_locals_dict_table[7] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_classes), MP_OBJ_FROM_PTR(&extratrees_model_get_n_classes_obj) };
394+
extratrees_model_locals_dict_table[8] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_trees), MP_OBJ_FROM_PTR(&extratrees_model_get_n_trees_obj) };
395+
extratrees_model_locals_dict_table[9] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_nodes_used), MP_OBJ_FROM_PTR(&extratrees_model_get_n_nodes_used_obj) };
396+
extratrees_model_locals_dict_table[10] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_trees_trained), MP_OBJ_FROM_PTR(&extratrees_model_get_n_trees_trained_obj) };
426397

427398
MP_OBJ_TYPE_SET_SLOT(&extratrees_model_type, locals_dict, (void*)&extratrees_model_locals_dict, 10);
428399

@@ -436,7 +407,6 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
436407
static const mp_rom_map_elem_t extratrees_model_locals_dict_table[] = {
437408
{ MP_ROM_QSTR(MP_QSTR_make_new), MP_ROM_PTR(&extratrees_model_make_new_obj) },
438409
{ MP_ROM_QSTR(MP_QSTR_predict), MP_ROM_PTR(&extratrees_model_predict_obj) },
439-
{ MP_ROM_QSTR(MP_QSTR_predict_proba), MP_ROM_PTR(&extratrees_model_predict_proba_obj) },
440410
{ MP_ROM_QSTR(MP_QSTR_train), MP_ROM_PTR(&extratrees_model_train_obj) },
441411
{ MP_ROM_QSTR(MP_QSTR_train_init), MP_ROM_PTR(&extratrees_model_train_init_obj) },
442412
{ MP_ROM_QSTR(MP_QSTR_train_step), MP_ROM_PTR(&extratrees_model_train_step_obj) },

tests/test_extratrees.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22
import array
33
import emlearn_extratrees
44

5+
6+
def argmax(arr, n=None):
    """Return the index of the largest value among the first ``n`` items of ``arr``.

    Args:
        arr: Indexable sequence of comparable values (e.g. a list or an
            ``array.array``). Must contain at least one element.
        n: Number of leading elements to consider. Defaults to ``len(arr)``.

    Returns:
        The index of the maximum value. On ties the lowest index wins,
        because a later element only replaces the current best when it is
        strictly greater.
    """
    if n is None:
        n = len(arr)
    best_idx = 0
    best_val = arr[0]
    for i in range(1, n):
        # Strict comparison keeps the first occurrence of the maximum.
        if arr[i] > best_val:
            best_val = arr[i]
            best_idx = i
    return best_idx
14+
515
def test_single_tree_prediction():
616
"""Test with just one tree to isolate prediction issues"""
717
print("=== Single Tree Prediction Debug ===")
@@ -33,7 +43,7 @@ def test_single_tree_prediction():
3343
expected = y[i]
3444

3545
test_features = array.array('h', [x1, x2])
36-
predicted = model.predict_proba(test_features, probabilities)
46+
predicted = model.predict(test_features, probabilities)
3747

3848
print(" ({}, {}) -> pred={}, exp={}, probs=[{:.3f}, {:.3f}] {}".format(
3949
x1, x2, predicted, expected,
@@ -68,12 +78,11 @@ def test_class_bias():
6878

6979
# Test on a clear class 1 example
7080
test_features = array.array('h', [300, 300])
71-
predicted = model.predict(test_features)
81+
probabilities = array.array('f', [0.0, 0.0])
82+
model.predict(test_features, probabilities)
83+
predicted = argmax(probabilities, 2)
7284

7385
print("Prediction for (300,300): {} (should strongly favor class 1)".format(predicted))
74-
75-
probabilities = array.array('f', [0.0, 0.0])
76-
model.predict_proba(test_features, probabilities)
7786
print("Probabilities: [{:.3f}, {:.3f}]".format(probabilities[0], probabilities[1]))
7887

7988
def test_manual_verification():
@@ -122,7 +131,7 @@ def test_manual_verification():
122131
all_correct = True
123132
for x1_val, expected in test_cases:
124133
test_features = array.array('h', [x1_val, 500]) # x2 irrelevant
125-
predicted = model.predict_proba(test_features, probabilities)
134+
predicted = model.predict(test_features, probabilities)
126135

127136
is_correct = predicted == expected
128137
if not is_correct:
@@ -180,12 +189,12 @@ def test_train_step_by_step():
180189
# Verify predictions still work
181190
probabilities = array.array('f', [0.0, 0.0])
182191
test_features = array.array('h', [0, 0])
183-
predicted = model.predict_proba(test_features, probabilities)
192+
predicted = model.predict(test_features, probabilities)
184193
print("Predict (0,0): {} probs=[{:.3f}, {:.3f}]".format(predicted, probabilities[0], probabilities[1]))
185194
assert predicted == 0, "Expected class 0"
186195

187196
test_features = array.array('h', [300, 300])
188-
predicted = model.predict_proba(test_features, probabilities)
197+
predicted = model.predict(test_features, probabilities)
189198
print("Predict (300,300): {} probs=[{:.3f}, {:.3f}]".format(predicted, probabilities[0], probabilities[1]))
190199
assert predicted == 1, "Expected class 1"
191200

@@ -216,12 +225,15 @@ def test_train_generator():
216225
assert model.get_n_trees_trained() == 5
217226

218227
# Verify predictions
228+
probabilities = array.array('f', [0.0, 0.0])
219229
test_features = array.array('h', [0, 0])
220-
predicted = model.predict(test_features)
230+
model.predict(test_features, probabilities)
231+
predicted = argmax(probabilities, 2)
221232
assert predicted == 0, "Expected class 0"
222233

223234
test_features = array.array('h', [300, 300])
224-
predicted = model.predict(test_features)
235+
model.predict(test_features, probabilities)
236+
predicted = argmax(probabilities, 2)
225237
assert predicted == 1, "Expected class 1"
226238

227239
print("✓ Generator training works")

tests/test_extratrees_cancer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_real_dataset():
6666
end_idx = start_idx + n_features
6767
features = array.array('h', X_test_flat[start_idx:end_idx])
6868

69-
predicted = model.predict_proba(features, probabilities)
69+
predicted = model.predict(features, probabilities)
7070
actual = y_test[i]
7171

7272
if predicted == actual:

tests/test_extratrees_wine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_wine():
8282
end_idx = start_idx + n_features
8383
features = array.array('h', X_test_flat[start_idx:end_idx])
8484

85-
predicted = model.predict_proba(features, probabilities)
85+
predicted = model.predict(features, probabilities)
8686
actual = y_test[i]
8787

8888
# Track per-class stats

tests/test_extratrees_xor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_xor_comprehensive():
5252

5353
for features, expected in test_cases:
5454
test_features = array.array('h', features)
55-
predicted = model.predict_proba(test_features, probabilities)
55+
predicted = model.predict(test_features, probabilities)
5656
is_correct = predicted == expected
5757
if is_correct:
5858
correct += 1
@@ -76,7 +76,7 @@ def test_xor_comprehensive():
7676

7777
for features, expected in interpolation_cases:
7878
test_features = array.array('h', features)
79-
predicted = model.predict_proba(test_features, probabilities)
79+
predicted = model.predict(test_features, probabilities)
8080
confidence = max(probabilities[0], probabilities[1])
8181

8282
if expected == "?":
@@ -120,7 +120,7 @@ def test_xor_robustness():
120120

121121
for features, expected in test_cases:
122122
test_features = array.array('h', features)
123-
predicted = model.predict_proba(test_features, probabilities)
123+
predicted = model.predict(test_features, probabilities)
124124
if predicted == expected:
125125
correct += 1
126126

@@ -178,7 +178,7 @@ def test_xor_different_values():
178178

179179
for features, expected in test_cases:
180180
test_features = array.array('h', features)
181-
predicted = model.predict_proba(test_features, probabilities)
181+
predicted = model.predict(test_features, probabilities)
182182
if predicted == expected:
183183
correct += 1
184184

0 commit comments

Comments
 (0)