@@ -82,7 +82,7 @@ static mp_obj_t builder_new(mp_obj_t trees_obj, mp_obj_t nodes_obj, mp_obj_t lea
8282 self -> trees .n_trees = 0 ;
8383 self -> trees .tree_roots = roots ;
8484
85- self -> trees .leaf_bits = 0 ; // XXX: only class supported so far
85+ self -> trees .leaf_bits = 0 ; // default to majority voting
8686 self -> trees .n_leaves = 0 ;
8787 self -> trees .leaves = leaves ;
8888
@@ -114,17 +114,19 @@ static mp_obj_t builder_del(mp_obj_t trees_obj) {
114114static MP_DEFINE_CONST_FUN_OBJ_1 (builder_del_obj , builder_del ) ;
115115
116116// set number of features and classes
117- static mp_obj_t builder_setdata (mp_obj_t self_obj , mp_obj_t features_obj , mp_obj_t classes_obj ) {
117+ static mp_obj_t builder_setdata (size_t n_args , const mp_obj_t * args ) {
118118
119- mp_obj_trees_builder_t * o = MP_OBJ_TO_PTR (self_obj );
119+ //mp_obj_t self_obj, mp_obj_t features_obj, mp_obj_t classes_obj, mp_obj_t leaf_bits_obj
120+ mp_obj_trees_builder_t * o = MP_OBJ_TO_PTR (args [0 ]);
120121 EmlTreesBuilder * self = & o -> builder ;
121122
122- self -> trees .n_features = mp_obj_get_int (features_obj );
123- self -> trees .n_classes = mp_obj_get_int (classes_obj );
123+ self -> trees .n_features = mp_obj_get_int (args [1 ]);
124+ self -> trees .n_classes = mp_obj_get_int (args [2 ]);
125+ self -> trees .leaf_bits = mp_obj_get_int (args [3 ]);
124126
125127 return MP_OBJ_FROM_PTR (o );
126128}
127- static MP_DEFINE_CONST_FUN_OBJ_3 (builder_setdata_obj , builder_setdata ) ;
129+ static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN (builder_setdata_obj , 4 , 4 , builder_setdata ) ;
128130
129131
130132// Add a node to the tree
@@ -187,14 +189,22 @@ static mp_obj_t builder_addleaf(mp_obj_t self_obj, mp_obj_t leaf_obj) {
187189 mp_obj_trees_builder_t * o = MP_OBJ_TO_PTR (self_obj );
188190 EmlTreesBuilder * self = & o -> builder ;
189191
190- mp_int_t leaf_value = mp_obj_get_int (leaf_obj );
191-
192192 if (self -> trees .n_leaves >= self -> max_leaves ) {
193193 mp_raise_ValueError (MP_ERROR_TEXT ("max leaves" ));
194194 }
195195
196196 const int leaf_index = self -> trees .n_leaves ++ ;
197- self -> trees .leaves [leaf_index ] = (uint8_t )leaf_value ;
197+
198+ if (self -> trees .leaf_bits == 0 ) {
199+ // majority voting, leaf should be a single integer (class index)
200+ //mp_float_t leaf_value = mp_obj_get_float(leaf_obj);
201+ mp_int_t leaf_int = mp_obj_get_int (leaf_obj );
202+ self -> trees .leaves [leaf_index ] = (uint8_t )leaf_int ;
203+ } else if (self -> trees .leaf_bits == 32 ) {
204+ //const mp_float_t leaf_value = mp_obj_get_float(leaf_obj);
205+ float * leaves = (float * )self -> trees .leaves ;
206+ leaves [leaf_index ] = mp_obj_get_float_to_f (leaf_obj );
207+ }
198208
199209 return mp_const_none ;
200210 }
0 commit comments