Skip to content

Commit 01717c6

Browse files
committed
trees: Support float leaves for regression
NOTE: requires a newer emlearn version that writes leaf_bits (lf) to the .csv export
1 parent 7936c22 commit 01717c6

2 files changed

Lines changed: 37 additions & 11 deletions

File tree

src/emlearn_trees/emlearn_trees.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,49 @@ def load_model(builder, f):
1212
leaves_found = 0
1313
n_classes = None
1414
n_features = None
15+
leaf_bits = 0 # default if not specified
1516

1617
for line in f:
1718
line = line.rstrip('\r')
1819
line = line.rstrip('\n')
1920
tok = line.split(',')
2021
kind = tok[0]
22+
23+
missing = leaf_bits is None or n_features is None or n_classes is None
24+
2125
if kind == 'r':
26+
assert not missing, 'missing metadata before roots'
2227
root = int(tok[1])
2328
builder.addroot(root)
2429
elif kind == 'n':
30+
assert not missing, 'missing metadata before nodes'
2531
feature = int(tok[1])
2632
value = int(float(tok[2]))
2733
left = int(tok[3])
2834
right = int(tok[4])
2935
builder.addnode(left, right, feature, value)
3036
elif kind == 'l':
31-
leaf = int(tok[1])
37+
assert not missing, 'missing metadata before leaves'
38+
assert len(tok) == 2, len(tok)
39+
if leaf_bits == 32:
40+
leaf = float(tok[1])
41+
else:
42+
leaf = int(tok[1])
3243
builder.addleaf(leaf)
3344
leaves_found += 1
45+
# metadata
3446
elif kind == 'f':
3547
n_features = int(tok[1])
3648
elif kind == 'c':
3749
n_classes = int(tok[1])
50+
elif kind == 'lf':
51+
leaf_bits = int(tok[1])
3852
else:
3953
# unknown value
4054
pass
4155

42-
builder.setdata(n_features, n_classes)
56+
if not missing:
57+
# FIXME: pass leaf_bits
58+
builder.setdata(n_features, n_classes, leaf_bits)
4359

4460
#print('load-model', leaves_found)

src/emlearn_trees/trees.c

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ static mp_obj_t builder_new(mp_obj_t trees_obj, mp_obj_t nodes_obj, mp_obj_t lea
8282
self->trees.n_trees = 0;
8383
self->trees.tree_roots = roots;
8484

85-
self->trees.leaf_bits = 0; // XXX: only class supported so far
85+
self->trees.leaf_bits = 0; // default to majority voting
8686
self->trees.n_leaves = 0;
8787
self->trees.leaves = leaves;
8888

@@ -114,17 +114,19 @@ static mp_obj_t builder_del(mp_obj_t trees_obj) {
114114
static MP_DEFINE_CONST_FUN_OBJ_1(builder_del_obj, builder_del);
115115

116116
// set number of features and classes
117-
static mp_obj_t builder_setdata(mp_obj_t self_obj, mp_obj_t features_obj, mp_obj_t classes_obj) {
117+
static mp_obj_t builder_setdata(size_t n_args, const mp_obj_t *args) {
118118

119-
mp_obj_trees_builder_t *o = MP_OBJ_TO_PTR(self_obj);
119+
//mp_obj_t self_obj, mp_obj_t features_obj, mp_obj_t classes_obj, mp_obj_t leaf_bits_obj
120+
mp_obj_trees_builder_t *o = MP_OBJ_TO_PTR(args[0]);
120121
EmlTreesBuilder *self = &o->builder;
121122

122-
self->trees.n_features = mp_obj_get_int(features_obj);
123-
self->trees.n_classes = mp_obj_get_int(classes_obj);
123+
self->trees.n_features = mp_obj_get_int(args[1]);
124+
self->trees.n_classes = mp_obj_get_int(args[2]);
125+
self->trees.leaf_bits = mp_obj_get_int(args[3]);
124126

125127
return MP_OBJ_FROM_PTR(o);
126128
}
127-
static MP_DEFINE_CONST_FUN_OBJ_3(builder_setdata_obj, builder_setdata);
129+
static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(builder_setdata_obj, 4, 4, builder_setdata);
128130

129131

130132
// Add a node to the tree
@@ -187,14 +189,22 @@ static mp_obj_t builder_addleaf(mp_obj_t self_obj, mp_obj_t leaf_obj) {
187189
mp_obj_trees_builder_t *o = MP_OBJ_TO_PTR(self_obj);
188190
EmlTreesBuilder *self = &o->builder;
189191

190-
mp_int_t leaf_value = mp_obj_get_int(leaf_obj);
191-
192192
if (self->trees.n_leaves >= self->max_leaves) {
193193
mp_raise_ValueError(MP_ERROR_TEXT("max leaves"));
194194
}
195195

196196
const int leaf_index = self->trees.n_leaves++;
197-
self->trees.leaves[leaf_index] = (uint8_t)leaf_value;
197+
198+
if (self->trees.leaf_bits == 0) {
199+
// majority voting, leaf should be a single integer (class index)
200+
//mp_float_t leaf_value = mp_obj_get_float(leaf_obj);
201+
mp_int_t leaf_int = mp_obj_get_int(leaf_obj);
202+
self->trees.leaves[leaf_index] = (uint8_t)leaf_int;
203+
} else if (self->trees.leaf_bits == 32) {
204+
//const mp_float_t leaf_value = mp_obj_get_float(leaf_obj);
205+
float *leaves = (float *)self->trees.leaves;
206+
leaves[leaf_index] = mp_obj_get_float_to_f(leaf_obj);
207+
}
198208

199209
return mp_const_none;
200210
}

0 commit comments

Comments
 (0)