Skip to content

Commit f32c66e

Browse files
committed
extratrees: Use keyword arguments for new
1 parent ae352b4 commit f32c66e

6 files changed

Lines changed: 82 additions & 73 deletions

File tree

src/emlearn_extratrees/emlearn_extratrees.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,36 @@
77
pass
88

99

10+
def new(n_features, n_classes, *,
11+
n_trees=10, max_depth=10, min_samples_leaf=1,
12+
n_thresholds=10, subsample_ratio=1.0, feature_subsample_ratio=1.0,
13+
max_nodes=1000, max_samples=1000,
14+
rng_seed=42, use_global_feature_range=False):
15+
"""Create a new ExtraTrees model.
16+
17+
Args:
18+
n_features: Number of input features (required)
19+
n_classes: Number of output classes (required)
20+
n_trees: Number of trees in the ensemble
21+
max_depth: Maximum tree depth
22+
min_samples_leaf: Minimum samples at a leaf node
23+
n_thresholds: Random thresholds drawn per feature split
24+
subsample_ratio: Fraction of samples used per tree (0.0-1.0)
25+
feature_subsample_ratio: Fraction of features considered per split (0.0-1.0)
26+
max_nodes: Max pre-allocated nodes (memory)
27+
max_samples: Max pre-allocated samples (memory)
28+
rng_seed: Random number generator seed
29+
use_global_feature_range: Use global feature min/max instead of per-node
30+
"""
31+
return make_new(
32+
n_features, n_classes,
33+
n_trees, max_depth, min_samples_leaf, n_thresholds,
34+
subsample_ratio, feature_subsample_ratio,
35+
max_nodes, max_samples,
36+
rng_seed, int(use_global_feature_range)
37+
)
38+
39+
1040
def train_steps(model, X, y):
1141
"""Generator that trains one node at a time, yielding after each step.
1242

src/emlearn_extratrees/extratrees.c

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,24 @@ static const mp_obj_type_t extratrees_model_type;
3737
#endif
3838

3939
// Create a new instance
40-
static mp_obj_t extratrees_model_new(size_t n_args, const mp_obj_t *args) {
41-
// Args: n_features, n_classes, [n_trees], [max_depth], [min_samples_leaf], [n_thresholds],
42-
// [subsample_ratio], [feature_subsample_ratio], [max_nodes], [max_samples], [rng_seed], [use_global_feature_range]
43-
if (n_args < 2 || n_args > 12) {
44-
mp_raise_ValueError(MP_ERROR_TEXT("Expected 2-12 arguments: n_features, n_classes, [n_trees=10], [max_depth=10], [min_samples_leaf=1], [n_thresholds=10], [subsample_ratio=1.0], [feature_subsample_ratio=1.0], [max_nodes=1000], [max_samples=1000], [rng_seed=42], [use_global_feature_range=0]"));
40+
static mp_obj_t extratrees_model_make_new(size_t n_args, const mp_obj_t *args) {
41+
// All 12 args provided by Python wrapper (defaults handled there)
42+
if (n_args != 12) {
43+
mp_raise_ValueError(MP_ERROR_TEXT("Expected 12 arguments"));
4544
}
4645

4746
mp_int_t n_features = mp_obj_get_int(args[0]);
4847
mp_int_t n_classes = mp_obj_get_int(args[1]);
49-
mp_int_t n_trees = (n_args > 2) ? mp_obj_get_int(args[2]) : 10;
50-
mp_int_t max_depth = (n_args > 3) ? mp_obj_get_int(args[3]) : 10;
51-
mp_int_t min_samples_leaf = (n_args > 4) ? mp_obj_get_int(args[4]) : 1;
52-
mp_int_t n_thresholds = (n_args > 5) ? mp_obj_get_int(args[5]) : 10;
53-
float subsample_ratio = (n_args > 6) ? mp_obj_get_float_to_f(args[6]) : 1.0f;
54-
float feature_subsample_ratio = (n_args > 7) ? mp_obj_get_float_to_f(args[7]) : 1.0f;
55-
mp_int_t max_nodes = (n_args > 8) ? mp_obj_get_int(args[8]) : 1000;
56-
mp_int_t max_samples = (n_args > 9) ? mp_obj_get_int(args[9]) : 1000;
57-
mp_int_t rng_seed = (n_args > 10) ? mp_obj_get_int(args[10]) : 42;
58-
mp_int_t use_global_feature_range = (n_args > 11) ? mp_obj_get_int(args[11]) : 0;
48+
mp_int_t n_trees = mp_obj_get_int(args[2]);
49+
mp_int_t max_depth = mp_obj_get_int(args[3]);
50+
mp_int_t min_samples_leaf = mp_obj_get_int(args[4]);
51+
mp_int_t n_thresholds = mp_obj_get_int(args[5]);
52+
float subsample_ratio = mp_obj_get_float_to_f(args[6]);
53+
float feature_subsample_ratio = mp_obj_get_float_to_f(args[7]);
54+
mp_int_t max_nodes = mp_obj_get_int(args[8]);
55+
mp_int_t max_samples = mp_obj_get_int(args[9]);
56+
mp_int_t rng_seed = mp_obj_get_int(args[10]);
57+
mp_int_t use_global_feature_range = mp_obj_get_int(args[11]);
5958

6059
// Allocate space
6160
mp_obj_extratrees_model_t *o = \
@@ -116,7 +115,7 @@ static mp_obj_t extratrees_model_new(size_t n_args, const mp_obj_t *args) {
116115
return MP_OBJ_FROM_PTR(o);
117116
}
118117
// Define a Python reference to the function above
119-
static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(extratrees_model_new_obj, 2, 12, extratrees_model_new);
118+
static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(extratrees_model_make_new_obj, 12, 12, extratrees_model_make_new);
120119

121120
// Delete an instance
122121
static mp_obj_t extratrees_model_del(mp_obj_t self_obj) {
@@ -397,32 +396,33 @@ static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_get_n_trees_trained_obj, extra
397396
#if MICROPY_ENABLE_DYNRUNTIME
398397

399398
// Module setup
400-
mp_map_elem_t extratrees_model_locals_dict_table[11];
399+
mp_map_elem_t extratrees_model_locals_dict_table[12];
401400
static MP_DEFINE_CONST_DICT(extratrees_model_locals_dict, extratrees_model_locals_dict_table);
402401

403402
// Module setup entrypoint
404403
mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *args) {
405404
// This must be first, it sets up the globals dict and other things
406405
MP_DYNRUNTIME_INIT_ENTRY
407406

408-
mp_store_global(MP_QSTR_new, MP_OBJ_FROM_PTR(&extratrees_model_new_obj));
407+
mp_store_global(MP_QSTR_make_new, MP_OBJ_FROM_PTR(&extratrees_model_make_new_obj));
409408

410409
extratrees_model_type.base.type = (void*)&mp_fun_table.type_type;
411410
extratrees_model_type.flags = MP_TYPE_FLAG_ITER_IS_CUSTOM;
412411
extratrees_model_type.name = MP_QSTR_extratrees;
413412

414413
// methods
415-
extratrees_model_locals_dict_table[0] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict), MP_OBJ_FROM_PTR(&extratrees_model_predict_obj) };
416-
extratrees_model_locals_dict_table[1] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict_proba), MP_OBJ_FROM_PTR(&extratrees_model_predict_proba_obj) };
417-
extratrees_model_locals_dict_table[2] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train), MP_OBJ_FROM_PTR(&extratrees_model_train_obj) };
418-
extratrees_model_locals_dict_table[3] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train_init), MP_OBJ_FROM_PTR(&extratrees_model_train_init_obj) };
419-
extratrees_model_locals_dict_table[4] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train_step), MP_OBJ_FROM_PTR(&extratrees_model_train_step_obj) };
420-
extratrees_model_locals_dict_table[5] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&extratrees_model_del_obj) };
421-
extratrees_model_locals_dict_table[6] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_features), MP_OBJ_FROM_PTR(&extratrees_model_get_n_features_obj) };
422-
extratrees_model_locals_dict_table[7] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_classes), MP_OBJ_FROM_PTR(&extratrees_model_get_n_classes_obj) };
423-
extratrees_model_locals_dict_table[8] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_trees), MP_OBJ_FROM_PTR(&extratrees_model_get_n_trees_obj) };
424-
extratrees_model_locals_dict_table[9] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_nodes_used), MP_OBJ_FROM_PTR(&extratrees_model_get_n_nodes_used_obj) };
425-
extratrees_model_locals_dict_table[10] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_trees_trained), MP_OBJ_FROM_PTR(&extratrees_model_get_n_trees_trained_obj) };
414+
extratrees_model_locals_dict_table[0] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_make_new), MP_OBJ_FROM_PTR(&extratrees_model_make_new_obj) };
415+
extratrees_model_locals_dict_table[1] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict), MP_OBJ_FROM_PTR(&extratrees_model_predict_obj) };
416+
extratrees_model_locals_dict_table[2] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict_proba), MP_OBJ_FROM_PTR(&extratrees_model_predict_proba_obj) };
417+
extratrees_model_locals_dict_table[3] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train), MP_OBJ_FROM_PTR(&extratrees_model_train_obj) };
418+
extratrees_model_locals_dict_table[4] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train_init), MP_OBJ_FROM_PTR(&extratrees_model_train_init_obj) };
419+
extratrees_model_locals_dict_table[5] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_train_step), MP_OBJ_FROM_PTR(&extratrees_model_train_step_obj) };
420+
extratrees_model_locals_dict_table[6] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&extratrees_model_del_obj) };
421+
extratrees_model_locals_dict_table[7] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_features), MP_OBJ_FROM_PTR(&extratrees_model_get_n_features_obj) };
422+
extratrees_model_locals_dict_table[8] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_classes), MP_OBJ_FROM_PTR(&extratrees_model_get_n_classes_obj) };
423+
extratrees_model_locals_dict_table[9] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_trees), MP_OBJ_FROM_PTR(&extratrees_model_get_n_trees_obj) };
424+
extratrees_model_locals_dict_table[10] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_nodes_used), MP_OBJ_FROM_PTR(&extratrees_model_get_n_nodes_used_obj) };
425+
extratrees_model_locals_dict_table[11] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_n_trees_trained), MP_OBJ_FROM_PTR(&extratrees_model_get_n_trees_trained_obj) };
426426

427427
MP_OBJ_TYPE_SET_SLOT(&extratrees_model_type, locals_dict, (void*)&extratrees_model_locals_dict, 10);
428428

@@ -434,6 +434,7 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
434434

435435
// User C module mode
436436
static const mp_rom_map_elem_t extratrees_model_locals_dict_table[] = {
437+
{ MP_ROM_QSTR(MP_QSTR_make_new), MP_ROM_PTR(&extratrees_model_make_new_obj) },
437438
{ MP_ROM_QSTR(MP_QSTR_predict), MP_ROM_PTR(&extratrees_model_predict_obj) },
438439
{ MP_ROM_QSTR(MP_QSTR_predict_proba), MP_ROM_PTR(&extratrees_model_predict_proba_obj) },
439440
{ MP_ROM_QSTR(MP_QSTR_train), MP_ROM_PTR(&extratrees_model_train_obj) },
@@ -457,7 +458,7 @@ static MP_DEFINE_CONST_OBJ_TYPE(
457458

458459
// Define module object
459460
static const mp_rom_map_elem_t emlearn_extratrees_globals_table[] = {
460-
{ MP_ROM_QSTR(MP_QSTR_new), MP_ROM_PTR(&extratrees_model_new_obj) },
461+
{ MP_ROM_QSTR(MP_QSTR_make_new), MP_ROM_PTR(&extratrees_model_make_new_obj) },
461462
};
462463
static MP_DEFINE_CONST_DICT(emlearn_extratrees_globals, emlearn_extratrees_globals_table);
463464

tests/test_extratrees.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_single_tree_prediction():
1616
print("Data: (0,0)->0, (1000,1000)->1")
1717

1818
# Single tree with enough depth and thresholds
19-
model = emlearn_extratrees.new(2, 2, 1, 5, 1, 100)
19+
model = emlearn_extratrees.new(2, 2, n_trees=1, max_depth=5, n_thresholds=100)
2020
model.train(X, y)
2121

2222
nodes_used = model.get_n_nodes_used()
@@ -63,7 +63,7 @@ def test_class_bias():
6363

6464
print("Data: 1 sample class 0, 5 samples class 1")
6565

66-
model = emlearn_extratrees.new(2, 2, 3, 5, 1, 50)
66+
model = emlearn_extratrees.new(2, 2, n_trees=3, max_depth=5, n_thresholds=50)
6767
model.train(X, y)
6868

6969
# Test on a clear class 1 example
@@ -103,7 +103,7 @@ def test_manual_verification():
103103
print(" ({}, {}) -> {}".format(x1, x2, label))
104104

105105
# Train with parameters that should definitely work
106-
model = emlearn_extratrees.new(2, 2, 10, 10, 1, 100)
106+
model = emlearn_extratrees.new(2, 2, n_trees=10, max_depth=10, n_thresholds=100)
107107
model.train(X, y)
108108

109109
print("\nNodes used: {}".format(model.get_n_nodes_used()))
@@ -160,7 +160,7 @@ def test_train_step_by_step():
160160
y = array.array('h', [0, 0, 1, 1])
161161

162162
# 3 trees
163-
model = emlearn_extratrees.new(2, 2, 3, 5, 1, 50)
163+
model = emlearn_extratrees.new(2, 2, n_trees=3, max_depth=5, n_thresholds=50)
164164
model.train_init(X, y)
165165

166166
steps = 0
@@ -203,7 +203,7 @@ def test_train_generator():
203203
])
204204
y = array.array('h', [0, 0, 1, 1])
205205

206-
model = emlearn_extratrees.new(2, 2, 5, 5, 1, 50)
206+
model = emlearn_extratrees.new(2, 2, n_trees=5, max_depth=5, n_thresholds=50)
207207

208208
trees_progress = []
209209
for trees_done in emlearn_extratrees.train_steps(model, X, y):
@@ -239,11 +239,11 @@ def test_train_step_same_as_train():
239239
y = array.array('h', [0, 0, 1, 1])
240240

241241
# Bulk train
242-
model_bulk = emlearn_extratrees.new(2, 2, 3, 5, 1, 50, 1.0, 1.0, 1000, 100, 42)
242+
model_bulk = emlearn_extratrees.new(2, 2, n_trees=3, max_depth=5, n_thresholds=50, max_samples=100, rng_seed=42)
243243
model_bulk.train(X, y)
244244

245245
# Step-by-step train
246-
model_step = emlearn_extratrees.new(2, 2, 3, 5, 1, 50, 1.0, 1.0, 1000, 100, 42)
246+
model_step = emlearn_extratrees.new(2, 2, n_trees=3, max_depth=5, n_thresholds=50, max_samples=100, rng_seed=42)
247247
model_step.train_init(X, y)
248248
while not model_step.train_step():
249249
pass

tests/test_extratrees_cancer.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,10 @@ def test_real_dataset():
4747

4848
# Create model
4949
model = emlearn_extratrees.new(
50-
30, # n_features
51-
2, # n_classes
52-
20, # n_trees
53-
12, # max_depth
54-
2, # min_samples_leaf
55-
15, # n_thresholds
56-
0.8, # subsample_ratio
57-
0.7, # feature_subsample_ratio
58-
3000, # max_nodes
59-
500, # max_samples
60-
42 # rng_seed
50+
30, 2,
51+
n_trees=20, max_depth=12, min_samples_leaf=2,
52+
n_thresholds=15, subsample_ratio=0.8, feature_subsample_ratio=0.7,
53+
max_nodes=3000, max_samples=500, rng_seed=42
6154
)
6255

6356
print("Training...")

tests/test_extratrees_wine.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,10 @@ def test_wine():
5656

5757
# Create model
5858
model = emlearn_extratrees.new(
59-
n_features, # n_features
60-
n_classes, # n_classes
61-
20, # n_trees
62-
12, # max_depth
63-
2, # min_samples_leaf
64-
15, # n_thresholds
65-
0.8, # subsample_ratio
66-
1.0, # feature_subsample_ratio
67-
3000, # max_nodes
68-
500, # max_samples
69-
42 # rng_seed
59+
n_features, n_classes,
60+
n_trees=20, max_depth=12, min_samples_leaf=2,
61+
n_thresholds=15, subsample_ratio=0.8, feature_subsample_ratio=1.0,
62+
max_nodes=3000, max_samples=500, rng_seed=42
7063
)
7164

7265
train_start = time.ticks_ms()

tests/test_extratrees_xor.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,10 @@ def test_xor_comprehensive():
2929

3030
# Test with ensemble of trees (now that individual trees work)
3131
model = emlearn_extratrees.new(
32-
2, # n_features
33-
2, # n_classes
34-
10, # n_trees (ensemble)
35-
8, # max_depth
36-
1, # min_samples_leaf
37-
10, # n_thresholds
38-
0.8, # subsample_ratio (80% for diversity)
39-
1.0, # feature_subsample_ratio (use both features)
40-
500, # max_nodes
41-
100, # max_samples
42-
42 # rng_seed
43-
)
32+
2, 2,
33+
n_trees=10, max_depth=8, n_thresholds=10,
34+
subsample_ratio=0.8, feature_subsample_ratio=1.0,
35+
max_nodes=500, max_samples=100, rng_seed=42)
4436

4537
model.train(X, y)
4638

@@ -118,7 +110,7 @@ def test_xor_robustness():
118110
for n_trees, max_depth, desc in configs:
119111
print(f"\nTesting {desc}:")
120112

121-
model = emlearn_extratrees.new(2, 2, n_trees, max_depth, 1, 8, 0.9, 1.0, 1000, 100, 123)
113+
model = emlearn_extratrees.new(2, 2, n_trees=n_trees, max_depth=max_depth, n_thresholds=8, subsample_ratio=0.9, max_samples=100, rng_seed=123)
122114
model.train(X, y)
123115

124116
# Test all XOR cases
@@ -170,7 +162,7 @@ def test_xor_different_values():
170162
X = array.array('h', X_data)
171163
y = array.array('h', y_data)
172164

173-
model = emlearn_extratrees.new(2, 2, 12, 10, 1, 10, 0.8, 1.0, 800, 100, 456)
165+
model = emlearn_extratrees.new(2, 2, n_trees=12, max_depth=10, n_thresholds=10, subsample_ratio=0.8, max_nodes=800, max_samples=100, rng_seed=456)
174166
model.train(X, y)
175167

176168
# Test

0 commit comments

Comments
 (0)