Skip to content

Commit ca6cd9b

Browse files
committed
neighbors: Avoid kwargs for predict()
Not needed and was buggy with usermod
1 parent c0be4a0 commit ca6cd9b

1 file changed

Lines changed: 6 additions & 8 deletions

File tree

src/emlearn_neighbors/neighbors.c

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -154,17 +154,14 @@ static MP_DEFINE_CONST_FUN_OBJ_3(neighbors_model_get_item_obj, neighbors_model_g
154154

155155

156156
// Takes a integer array
157-
static mp_obj_t neighbors_model_predict(mp_obj_t *self_obj,
158-
size_t n_args, size_t n_kw, mp_obj_t *args) {
159-
// Check number of arguments is valid
160-
mp_arg_check_num(n_args, n_kw, 2, 2, false);
157+
static mp_obj_t neighbors_model_predict(mp_obj_t self_obj, mp_obj_t data_obj) {
161158

162-
mp_obj_neighbors_model_t *o = MP_OBJ_TO_PTR(args[0]);
159+
mp_obj_neighbors_model_t *o = MP_OBJ_TO_PTR(self_obj);
163160
EmlNeighborsModel *self = &o->model;
164161

165162
// Extract buffer pointer and verify typecode
166163
mp_buffer_info_t bufinfo;
167-
mp_get_buffer_raise(args[1], &bufinfo, MP_BUFFER_RW);
164+
mp_get_buffer_raise(data_obj, &bufinfo, MP_BUFFER_RW);
168165
if (bufinfo.typecode != 'h') {
169166
mp_raise_ValueError(MP_ERROR_TEXT("expecting int16 array"));
170167
}
@@ -183,6 +180,7 @@ static mp_obj_t neighbors_model_predict(mp_obj_t *self_obj,
183180

184181
return mp_obj_new_int(out);
185182
}
183+
static MP_DEFINE_CONST_FUN_OBJ_2(neighbors_model_predict_obj, neighbors_model_predict);
186184

187185
// Access details about prediction result
188186
static mp_obj_t neighbors_model_get_result(mp_obj_t self_obj, mp_obj_t index_obj) {
@@ -225,7 +223,7 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
225223
neighbors_model_type.flags = MP_TYPE_FLAG_ITER_IS_CUSTOM;
226224
neighbors_model_type.name = MP_QSTR_emlneighbors;
227225
// methods
228-
neighbors_model_locals_dict_table[0] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict), MP_DYNRUNTIME_MAKE_FUNCTION(neighbors_model_predict) };
226+
neighbors_model_locals_dict_table[0] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict), MP_OBJ_FROM_PTR(&neighbors_model_predict_obj) };
229227
neighbors_model_locals_dict_table[1] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_additem), MP_OBJ_FROM_PTR(&neighbors_model_additem_obj) };
230228
neighbors_model_locals_dict_table[2] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&neighbors_model_del_obj) };
231229
neighbors_model_locals_dict_table[3] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_getresult), MP_OBJ_FROM_PTR(&neighbors_model_get_result_obj) };
@@ -242,7 +240,7 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
242240
// Define the class
243241
static const mp_rom_map_elem_t neighbors_model_locals_dict_table[] = {
244242

245-
{ MP_ROM_QSTR(MP_QSTR_predict), MP_ROM_PTR(&neighbors_model_predict) },
243+
{ MP_ROM_QSTR(MP_QSTR_predict), MP_ROM_PTR(&neighbors_model_predict_obj) },
246244
{ MP_ROM_QSTR(MP_QSTR_additem), MP_ROM_PTR(&neighbors_model_additem_obj) },
247245
{ MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&neighbors_model_del_obj) },
248246
{ MP_ROM_QSTR(MP_QSTR_getresult), MP_ROM_PTR(&neighbors_model_get_result_obj) },

0 commit comments

Comments
 (0)