@@ -269,7 +269,7 @@ static mp_obj_t extratrees_model_train_step(mp_obj_t self_obj) {
269269static MP_DEFINE_CONST_FUN_OBJ_1 (extratrees_model_train_step_obj , extratrees_model_train_step ) ;
270270
271271// Predict using the model (returns class probabilities)
272- static mp_obj_t extratrees_model_predict_proba (size_t n_args , const mp_obj_t * args ) {
272+ static mp_obj_t extratrees_model_predict (size_t n_args , const mp_obj_t * args ) {
273273 if (n_args != 3 ) {
274274 mp_raise_ValueError (MP_ERROR_TEXT ("Expected 3 arguments: self, features, probabilities" ));
275275 }
@@ -309,37 +309,9 @@ static mp_obj_t extratrees_model_predict_proba(size_t n_args, const mp_obj_t *ar
309309
310310 return mp_obj_new_int (predicted_class );
311311}
312- static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN (extratrees_model_predict_proba_obj , 3 , 3 , extratrees_model_predict_proba ) ;
313-
314- // Predict using the model (returns only class label)
315- static mp_obj_t extratrees_model_predict (size_t n_args , const mp_obj_t * args ) {
316- if (n_args != 2 ) {
317- mp_raise_ValueError (MP_ERROR_TEXT ("Expected 2 arguments: self, features" ));
318- }
312+ static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN (extratrees_model_predict_obj , 3 , 3 , extratrees_model_predict ) ;
319313
320- mp_obj_extratrees_model_t * o = MP_OBJ_TO_PTR (args [0 ]);
321- EmlExtraTreesModel * model = & o -> model ;
322- EmlExtraTreesWorkspace * workspace = & o -> workspace ;
323-
324- // Extract features buffer pointer and verify typecode
325- mp_buffer_info_t features_bufinfo ;
326- mp_get_buffer_raise (args [1 ], & features_bufinfo , MP_BUFFER_READ );
327- if (features_bufinfo .typecode != 'h' ) { // int16_t
328- mp_raise_ValueError (MP_ERROR_TEXT ("features expecting int16 array" ));
329- }
330- const int16_t * features = features_bufinfo .buf ;
331- const int n_features = features_bufinfo .len / sizeof (int16_t );
332-
333- if (n_features != model -> n_features ) {
334- mp_raise_ValueError (MP_ERROR_TEXT ("Feature count mismatch" ));
335- }
336-
337- // Make prediction using pre-allocated workspace arrays
338- int16_t predicted_class = eml_extratrees_predict_proba (model , features , workspace -> probabilities , workspace -> votes );
339-
340- return mp_obj_new_int (predicted_class );
341- }
342- static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN (extratrees_model_predict_obj , 2 , 2 , extratrees_model_predict ) ;
314+ // (old predict_proba removed - predict now handles both)
343315
344316// Get number of features
345317static mp_obj_t extratrees_model_get_n_features (mp_obj_t self_obj ) {
@@ -396,7 +368,7 @@ static MP_DEFINE_CONST_FUN_OBJ_1(extratrees_model_get_n_trees_trained_obj, extra
396368#if MICROPY_ENABLE_DYNRUNTIME
397369
398370// Module setup
399- mp_map_elem_t extratrees_model_locals_dict_table [12 ];
371+ mp_map_elem_t extratrees_model_locals_dict_table [11 ];
400372static MP_DEFINE_CONST_DICT (extratrees_model_locals_dict , extratrees_model_locals_dict_table ) ;
401373
402374// Module setup entrypoint
@@ -413,16 +385,15 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
413385 // methods
414386 extratrees_model_locals_dict_table [0 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_make_new ), MP_OBJ_FROM_PTR (& extratrees_model_make_new_obj ) };
415387 extratrees_model_locals_dict_table [1 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_predict ), MP_OBJ_FROM_PTR (& extratrees_model_predict_obj ) };
416- extratrees_model_locals_dict_table [2 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_predict_proba ), MP_OBJ_FROM_PTR (& extratrees_model_predict_proba_obj ) };
417- extratrees_model_locals_dict_table [3 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_train ), MP_OBJ_FROM_PTR (& extratrees_model_train_obj ) };
418- extratrees_model_locals_dict_table [4 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_train_init ), MP_OBJ_FROM_PTR (& extratrees_model_train_init_obj ) };
419- extratrees_model_locals_dict_table [5 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_train_step ), MP_OBJ_FROM_PTR (& extratrees_model_train_step_obj ) };
420- extratrees_model_locals_dict_table [6 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR___del__ ), MP_OBJ_FROM_PTR (& extratrees_model_del_obj ) };
421- extratrees_model_locals_dict_table [7 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_get_n_features ), MP_OBJ_FROM_PTR (& extratrees_model_get_n_features_obj ) };
422- extratrees_model_locals_dict_table [8 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_get_n_classes ), MP_OBJ_FROM_PTR (& extratrees_model_get_n_classes_obj ) };
423- extratrees_model_locals_dict_table [9 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_get_n_trees ), MP_OBJ_FROM_PTR (& extratrees_model_get_n_trees_obj ) };
424- extratrees_model_locals_dict_table [10 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_get_n_nodes_used ), MP_OBJ_FROM_PTR (& extratrees_model_get_n_nodes_used_obj ) };
425- extratrees_model_locals_dict_table [11 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_get_n_trees_trained ), MP_OBJ_FROM_PTR (& extratrees_model_get_n_trees_trained_obj ) };
388+ extratrees_model_locals_dict_table [2 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_train ), MP_OBJ_FROM_PTR (& extratrees_model_train_obj ) };
389+ extratrees_model_locals_dict_table [3 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_train_init ), MP_OBJ_FROM_PTR (& extratrees_model_train_init_obj ) };
390+ extratrees_model_locals_dict_table [4 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_train_step ), MP_OBJ_FROM_PTR (& extratrees_model_train_step_obj ) };
391+ extratrees_model_locals_dict_table [5 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR___del__ ), MP_OBJ_FROM_PTR (& extratrees_model_del_obj ) };
392+ extratrees_model_locals_dict_table [6 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_get_n_features ), MP_OBJ_FROM_PTR (& extratrees_model_get_n_features_obj ) };
393+ extratrees_model_locals_dict_table [7 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_get_n_classes ), MP_OBJ_FROM_PTR (& extratrees_model_get_n_classes_obj ) };
394+ extratrees_model_locals_dict_table [8 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_get_n_trees ), MP_OBJ_FROM_PTR (& extratrees_model_get_n_trees_obj ) };
395+ extratrees_model_locals_dict_table [9 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_get_n_nodes_used ), MP_OBJ_FROM_PTR (& extratrees_model_get_n_nodes_used_obj ) };
396+ extratrees_model_locals_dict_table [10 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_get_n_trees_trained ), MP_OBJ_FROM_PTR (& extratrees_model_get_n_trees_trained_obj ) };
426397
427398 MP_OBJ_TYPE_SET_SLOT (& extratrees_model_type , locals_dict , (void * )& extratrees_model_locals_dict , 10 );
428399
@@ -436,7 +407,6 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
436407static const mp_rom_map_elem_t extratrees_model_locals_dict_table [] = {
437408 { MP_ROM_QSTR (MP_QSTR_make_new ), MP_ROM_PTR (& extratrees_model_make_new_obj ) },
438409 { MP_ROM_QSTR (MP_QSTR_predict ), MP_ROM_PTR (& extratrees_model_predict_obj ) },
439- { MP_ROM_QSTR (MP_QSTR_predict_proba ), MP_ROM_PTR (& extratrees_model_predict_proba_obj ) },
440410 { MP_ROM_QSTR (MP_QSTR_train ), MP_ROM_PTR (& extratrees_model_train_obj ) },
441411 { MP_ROM_QSTR (MP_QSTR_train_init ), MP_ROM_PTR (& extratrees_model_train_init_obj ) },
442412 { MP_ROM_QSTR (MP_QSTR_train_step ), MP_ROM_PTR (& extratrees_model_train_step_obj ) },
0 commit comments