Skip to content

Commit ae352b4

Browse files
committed
extratrees: Use proper API prefix for C code
1 parent d6b7d96 commit ae352b4

2 files changed

Lines changed: 52 additions & 52 deletions

File tree

src/emlearn_extratrees/eml_extratrees.c

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
#define printf(fmt, ...) ((void)0)
1313
#endif
1414

15-
typedef struct _EmlTreesNode {
15+
typedef struct _EmlExtraTreesNode {
1616
int8_t feature; // -1 for leaf nodes
1717
int16_t value; // threshold or class label
1818
int16_t left; // left child index
1919
int16_t right; // right child index
20-
} EmlTreesNode;
20+
} EmlExtraTreesNode;
2121

2222
typedef struct _NodeState {
2323
int16_t node_idx; // current node being processed
@@ -26,18 +26,18 @@ typedef struct _NodeState {
2626
int16_t depth; // current depth
2727
} NodeState;
2828

29-
typedef struct _EmlTreesConfig {
29+
typedef struct _EmlExtraTreesConfig {
3030
int16_t max_depth;
3131
int16_t min_samples_leaf;
3232
int16_t n_thresholds;
3333
float subsample_ratio; // subsample ratio as float (0.0 to 1.0)
3434
float feature_subsample_ratio; // feature subsample ratio as float (0.0 to 1.0)
3535
int16_t use_global_feature_range; // 0: per-node min/max, 1: global min/max
3636
uint32_t rng_seed;
37-
} EmlTreesConfig;
37+
} EmlExtraTreesConfig;
3838

39-
typedef struct _EmlTreesModel {
40-
EmlTreesNode *nodes; // Pre-allocated node array
39+
typedef struct _EmlExtraTreesModel {
40+
EmlExtraTreesNode *nodes; // Pre-allocated node array
4141
int16_t *tree_starts; // Start index for each tree
4242
int16_t max_nodes; // Maximum nodes available
4343
int16_t max_samples; // Maximum samples in training data
@@ -46,10 +46,10 @@ typedef struct _EmlTreesModel {
4646
int16_t n_classes; // Number of classes
4747
int16_t n_trees; // Number of trees
4848
int16_t n_trees_trained; // Number of trees fully trained
49-
EmlTreesConfig config;
50-
} EmlTreesModel;
49+
EmlExtraTreesConfig config;
50+
} EmlExtraTreesModel;
5151

52-
typedef struct _EmlTreesWorkspace {
52+
typedef struct _EmlExtraTreesWorkspace {
5353
uint16_t *sample_indices; // Sample indices for current tree
5454
uint16_t *feature_indices; // Feature indices for current tree
5555
int16_t *min_vals; // Min values per feature [n_features]
@@ -71,7 +71,7 @@ typedef struct _EmlTreesWorkspace {
7171
int16_t train_state; // 0=idle, 1=training, 2=done
7272
const int16_t *train_features; // Pointer to training features data
7373
const int16_t *train_labels; // Pointer to training labels data
74-
} EmlTreesWorkspace;
74+
} EmlExtraTreesWorkspace;
7575

7676
// Simple linear congruential generator
7777
static uint32_t eml_rand(uint32_t *state) {
@@ -108,8 +108,8 @@ static float calculate_gini_from_counts(const int16_t *counts, int16_t total, in
108108

109109

110110
// Partition samples based on feature threshold
111-
static int partition_samples(const int16_t *features, EmlTreesModel *model,
112-
EmlTreesWorkspace *workspace, int start, int end,
111+
static int partition_samples(const int16_t *features, EmlExtraTreesModel *model,
112+
EmlExtraTreesWorkspace *workspace, int start, int end,
113113
int8_t feature, int threshold) {
114114
int left = start;
115115
int right = end - 1;
@@ -142,8 +142,8 @@ static int partition_samples(const int16_t *features, EmlTreesModel *model,
142142

143143

144144

145-
// Add this debug version of eml_trees_predict_proba
146-
static int16_t eml_trees_predict_proba(const EmlTreesModel *model, const int16_t *features,
145+
// Add this debug version of eml_extratrees_predict_proba
146+
static int16_t eml_extratrees_predict_proba(const EmlExtraTreesModel *model, const int16_t *features,
147147
float *probabilities, int16_t *votes) {
148148

149149
// Initialize vote counts
@@ -259,7 +259,7 @@ static int get_majority_class(const int16_t *labels, const uint16_t *indices,
259259
// but will lead to better splits at deeper levels
260260

261261
static int find_best_split(const int16_t *features, const int16_t *labels,
262-
EmlTreesModel *model, EmlTreesWorkspace *workspace,
262+
EmlExtraTreesModel *model, EmlExtraTreesWorkspace *workspace,
263263
int start, int end, int n_features_subset,
264264
int8_t *best_feature, int *best_threshold,
265265
float *best_improvement_out) {
@@ -372,7 +372,7 @@ static int find_best_split(const int16_t *features, const int16_t *labels,
372372
}
373373

374374
// ALSO: Ensure stopping criteria allow deep enough trees for XOR
375-
static int build_tree(EmlTreesModel *model, EmlTreesWorkspace *workspace,
375+
static int build_tree(EmlExtraTreesModel *model, EmlExtraTreesWorkspace *workspace,
376376
const int16_t *features, const int16_t *labels) {
377377

378378
int16_t tree_start = model->n_nodes_used;
@@ -554,7 +554,7 @@ static int build_tree(EmlTreesModel *model, EmlTreesWorkspace *workspace,
554554

555555

556556
// Initialize step-by-step training
557-
static int16_t eml_trees_train_init(EmlTreesModel *model, EmlTreesWorkspace *workspace,
557+
static int16_t eml_extratrees_train_init(EmlExtraTreesModel *model, EmlExtraTreesWorkspace *workspace,
558558
const int16_t *features, const int16_t *labels) {
559559
model->n_nodes_used = 0;
560560
model->n_trees_trained = 0;
@@ -599,7 +599,7 @@ static int16_t eml_trees_train_init(EmlTreesModel *model, EmlTreesWorkspace *wor
599599

600600
// Process one node in step-by-step training
601601
// Returns: 1=training complete, 0=more steps needed, -1=error
602-
static int16_t eml_trees_train_step(EmlTreesModel *model, EmlTreesWorkspace *workspace) {
602+
static int16_t eml_extratrees_train_step(EmlExtraTreesModel *model, EmlExtraTreesWorkspace *workspace) {
603603
if (workspace->train_state != 1) {
604604
return -1;
605605
}
@@ -766,14 +766,14 @@ static int16_t eml_trees_train_step(EmlTreesModel *model, EmlTreesWorkspace *wor
766766
}
767767

768768
// Train all trees at once (convenience wrapper)
769-
static int16_t eml_trees_train(EmlTreesModel *model, EmlTreesWorkspace *workspace,
769+
static int16_t eml_extratrees_train(EmlExtraTreesModel *model, EmlExtraTreesWorkspace *workspace,
770770
const int16_t *features, const int16_t *labels) {
771771

772-
int16_t result = eml_trees_train_init(model, workspace, features, labels);
772+
int16_t result = eml_extratrees_train_init(model, workspace, features, labels);
773773
if (result != 0) return result;
774774

775775
while (1) {
776-
result = eml_trees_train_step(model, workspace);
776+
result = eml_extratrees_train_step(model, workspace);
777777
if (result < 0) return result;
778778
if (result == 1) break; // done
779779
}

src/emlearn_extratrees/extratrees.c

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ void *memset(void *s, int c, size_t n) {
2424
// MicroPython type for ExtraTrees model
2525
typedef struct _mp_obj_extratrees_model_t {
2626
mp_obj_base_t base;
27-
EmlTreesModel model;
28-
EmlTreesWorkspace workspace;
27+
EmlExtraTreesModel model;
28+
EmlExtraTreesWorkspace workspace;
2929
mp_obj_t train_X_obj; // Reference to X Python object (prevents GC during step training)
3030
mp_obj_t train_y_obj; // Reference to y Python object (prevents GC during step training)
3131
} mp_obj_extratrees_model_t;
@@ -61,10 +61,10 @@ static mp_obj_t extratrees_model_new(size_t n_args, const mp_obj_t *args) {
6161
mp_obj_extratrees_model_t *o = \
6262
mp_obj_malloc(mp_obj_extratrees_model_t, (mp_obj_type_t *)&extratrees_model_type);
6363

64-
EmlTreesModel *model = &o->model;
65-
EmlTreesWorkspace *workspace = &o->workspace;
66-
memset(model, 0, sizeof(EmlTreesModel));
67-
memset(workspace, 0, sizeof(EmlTreesWorkspace));
64+
EmlExtraTreesModel *model = &o->model;
65+
EmlExtraTreesWorkspace *workspace = &o->workspace;
66+
memset(model, 0, sizeof(EmlExtraTreesModel));
67+
memset(workspace, 0, sizeof(EmlExtraTreesWorkspace));
6868

6969
// Configure model
7070
model->n_features = n_features;
@@ -85,7 +85,7 @@ static mp_obj_t extratrees_model_new(size_t n_args, const mp_obj_t *args) {
8585
model->config.use_global_feature_range = use_global_feature_range;
8686

8787
// Allocate model buffers
88-
model->nodes = m_new(EmlTreesNode, max_nodes);
88+
model->nodes = m_new(EmlExtraTreesNode, max_nodes);
8989
model->tree_starts = m_new(int16_t, n_trees);
9090

9191
// Allocate workspace buffers
@@ -110,7 +110,7 @@ static mp_obj_t extratrees_model_new(size_t n_args, const mp_obj_t *args) {
110110
o->train_y_obj = mp_const_none;
111111

112112
// Initialize nodes and tree starts
113-
memset(model->nodes, 0, sizeof(EmlTreesNode) * max_nodes);
113+
memset(model->nodes, 0, sizeof(EmlExtraTreesNode) * max_nodes);
114114
memset(model->tree_starts, 0, sizeof(int16_t) * n_trees);
115115

116116
return MP_OBJ_FROM_PTR(o);
@@ -121,11 +121,11 @@ static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(extratrees_model_new_obj, 2, 12, extr
121121
// Delete an instance
122122
static mp_obj_t extratrees_model_del(mp_obj_t self_obj) {
123123
mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(self_obj);
124-
EmlTreesModel *model = &o->model;
125-
EmlTreesWorkspace *workspace = &o->workspace;
124+
EmlExtraTreesModel *model = &o->model;
125+
EmlExtraTreesWorkspace *workspace = &o->workspace;
126126

127127
// Free allocated memory
128-
m_del(EmlTreesNode, model->nodes, model->max_nodes);
128+
m_del(EmlExtraTreesNode, model->nodes, model->max_nodes);
129129
m_del(int16_t, model->tree_starts, model->n_trees);
130130
m_del(uint16_t, workspace->sample_indices, model->max_samples);
131131
m_del(uint16_t, workspace->feature_indices, model->n_features);
@@ -155,8 +155,8 @@ static mp_obj_t extratrees_model_train(size_t n_args, const mp_obj_t *args) {
155155
}
156156

157157
mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(args[0]);
158-
EmlTreesModel *model = &o->model;
159-
EmlTreesWorkspace *workspace = &o->workspace;
158+
EmlExtraTreesModel *model = &o->model;
159+
EmlExtraTreesWorkspace *workspace = &o->workspace;
160160

161161
// Extract X buffer
162162
mp_buffer_info_t X_bufinfo;
@@ -185,7 +185,7 @@ static mp_obj_t extratrees_model_train(size_t n_args, const mp_obj_t *args) {
185185
workspace->n_samples = n_samples;
186186

187187
// Pass buffer pointers directly (no copy needed)
188-
int16_t result = eml_trees_train(model, workspace, X, y);
188+
int16_t result = eml_extratrees_train(model, workspace, X, y);
189189

190190
if (result != 0) {
191191
mp_raise_ValueError(MP_ERROR_TEXT("Training failed"));
@@ -202,8 +202,8 @@ static mp_obj_t extratrees_model_train_init(size_t n_args, const mp_obj_t *args)
202202
}
203203

204204
mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(args[0]);
205-
EmlTreesModel *model = &o->model;
206-
EmlTreesWorkspace *workspace = &o->workspace;
205+
EmlExtraTreesModel *model = &o->model;
206+
EmlExtraTreesWorkspace *workspace = &o->workspace;
207207

208208
// Extract X buffer
209209
mp_buffer_info_t X_bufinfo;
@@ -236,7 +236,7 @@ static mp_obj_t extratrees_model_train_init(size_t n_args, const mp_obj_t *args)
236236
o->train_y_obj = args[2];
237237

238238
// Pass buffer pointers directly (no copy needed)
239-
int16_t result = eml_trees_train_init(model, workspace, X, y);
239+
int16_t result = eml_extratrees_train_init(model, workspace, X, y);
240240
if (result != 0) {
241241
o->train_X_obj = mp_const_none;
242242
o->train_y_obj = mp_const_none;
@@ -251,10 +251,10 @@ static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(extratrees_model_train_init_obj, 3, 3
251251
// Returns: 1=training complete, 0=more steps needed
252252
static mp_obj_t extratrees_model_train_step(mp_obj_t self_obj) {
253253
mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(self_obj);
254-
EmlTreesModel *model = &o->model;
255-
EmlTreesWorkspace *workspace = &o->workspace;
254+
EmlExtraTreesModel *model = &o->model;
255+
EmlExtraTreesWorkspace *workspace = &o->workspace;
256256

257-
int16_t result = eml_trees_train_step(model, workspace);
257+
int16_t result = eml_extratrees_train_step(model, workspace);
258258
if (result < 0) {
259259
mp_raise_ValueError(MP_ERROR_TEXT("Train step failed"));
260260
}
@@ -276,8 +276,8 @@ static mp_obj_t extratrees_model_predict_proba(size_t n_args, const mp_obj_t *ar
276276
}
277277

278278
mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(args[0]);
279-
EmlTreesModel *model = &o->model;
280-
EmlTreesWorkspace *workspace = &o->workspace;
279+
EmlExtraTreesModel *model = &o->model;
280+
EmlExtraTreesWorkspace *workspace = &o->workspace;
281281

282282
// Extract features buffer pointer and verify typecode
283283
mp_buffer_info_t features_bufinfo;
@@ -306,7 +306,7 @@ static mp_obj_t extratrees_model_predict_proba(size_t n_args, const mp_obj_t *ar
306306
}
307307

308308
// Make prediction using pre-allocated workspace arrays
309-
int16_t predicted_class = eml_trees_predict_proba(model, features, probabilities, workspace->votes);
309+
int16_t predicted_class = eml_extratrees_predict_proba(model, features, probabilities, workspace->votes);
310310

311311
return mp_obj_new_int(predicted_class);
312312
}
@@ -319,8 +319,8 @@ static mp_obj_t extratrees_model_predict(size_t n_args, const mp_obj_t *args) {
319319
}
320320

321321
mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(args[0]);
322-
EmlTreesModel *model = &o->model;
323-
EmlTreesWorkspace *workspace = &o->workspace;
322+
EmlExtraTreesModel *model = &o->model;
323+
EmlExtraTreesWorkspace *workspace = &o->workspace;
324324

325325
// Extract features buffer pointer and verify typecode
326326
mp_buffer_info_t features_bufinfo;
@@ -336,7 +336,7 @@ static mp_obj_t extratrees_model_predict(size_t n_args, const mp_obj_t *args) {
336336
}
337337

338338
// Make prediction using pre-allocated workspace arrays
339-
int16_t predicted_class = eml_trees_predict_proba(model, features, workspace->probabilities, workspace->votes);
339+
int16_t predicted_class = eml_extratrees_predict_proba(model, features, workspace->probabilities, workspace->votes);
340340

341341
return mp_obj_new_int(predicted_class);
342342
}
@@ -345,7 +345,7 @@ static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(extratrees_model_predict_obj, 2, 2, e
345345
// Get number of features
346346
static mp_obj_t extratrees_model_get_n_features(mp_obj_t self_obj) {
347347
mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(self_obj);
348-
EmlTreesModel *model = &o->model;
348+
EmlExtraTreesModel *model = &o->model;
349349

350350
return mp_obj_new_int(model->n_features);
351351
}
@@ -355,7 +355,7 @@ static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_get_n_features_obj, extratrees
355355
// Get number of classes
356356
static mp_obj_t extratrees_model_get_n_classes(mp_obj_t self_obj) {
357357
mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(self_obj);
358-
EmlTreesModel *model = &o->model;
358+
EmlExtraTreesModel *model = &o->model;
359359

360360
return mp_obj_new_int(model->n_classes);
361361
}
@@ -365,7 +365,7 @@ static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_get_n_classes_obj, extratrees_
365365
// Get number of trees
366366
static mp_obj_t extratrees_model_get_n_trees(mp_obj_t self_obj) {
367367
mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(self_obj);
368-
EmlTreesModel *model = &o->model;
368+
EmlExtraTreesModel *model = &o->model;
369369

370370
return mp_obj_new_int(model->n_trees);
371371
}
@@ -375,7 +375,7 @@ static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_get_n_trees_obj, extratrees_mo
375375
// Get number of nodes used
376376
static mp_obj_t extratrees_model_get_n_nodes_used(mp_obj_t self_obj) {
377377
mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(self_obj);
378-
EmlTreesModel *model = &o->model;
378+
EmlExtraTreesModel *model = &o->model;
379379

380380
return mp_obj_new_int(model->n_nodes_used);
381381
}
@@ -384,7 +384,7 @@ static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_get_n_nodes_used_obj, extratre
384384
// Get number of trees trained
385385
static mp_obj_t extratrees_model_get_n_trees_trained(mp_obj_t self_obj) {
386386
mp_obj_extratrees_model_t *o = MP_OBJ_TO_PTR(self_obj);
387-
EmlTreesModel *model = &o->model;
387+
EmlExtraTreesModel *model = &o->model;
388388

389389
return mp_obj_new_int(model->n_trees_trained);
390390
}

0 commit comments

Comments
 (0)