Skip to content

Commit 74a375b

Browse files
committed
logreg: Fix get_bias
1 parent 9b50539 commit 74a375b

2 files changed

Lines changed: 37 additions & 23 deletions

File tree

src/emlearn_logreg/logreg.c

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -313,36 +313,53 @@ static mp_obj_t logreg_model_set_weights(mp_obj_t self_obj, mp_obj_t weights_obj
313313
}
314314
static MP_DEFINE_CONST_FUN_OBJ_2(logreg_model_set_weights_obj, logreg_model_set_weights);
315315

316-
static mp_obj_t logreg_model_get_bias(mp_obj_t self_obj) {
316+
static mp_obj_t logreg_model_get_bias(mp_obj_t self_obj, mp_obj_t out_obj) {
317317
mp_obj_logreg_model_t *o = MP_OBJ_TO_PTR(self_obj);
318318
logreg_model_t *self = &o->model;
319319

320-
mp_obj_t bias_obj = mp_obj_new_bytearray_by_ref(sizeof(float) * self->n_classes, self->biases);
321-
#if MICROPY_ENABLE_DYNRUNTIME
322-
mp_obj_array_t *arr = MP_OBJ_TO_PTR(bias_obj);
323-
arr->typecode = 'f';
324-
#else
325-
((mp_obj_base_t *)MP_OBJ_TO_PTR(bias_obj))->type = &mp_type_bytearray;
326-
#endif
327-
return bias_obj;
320+
mp_buffer_info_t bufinfo;
321+
mp_get_buffer_raise(out_obj, &bufinfo, MP_BUFFER_WRITE);
322+
if (bufinfo.typecode != 'f') {
323+
mp_raise_ValueError(MP_ERROR_TEXT("expecting float32 array"));
324+
}
325+
const size_t n_out = bufinfo.len / sizeof(float);
326+
if (n_out != self->n_classes) {
327+
mp_raise_ValueError(MP_ERROR_TEXT("bias buffer wrong length"));
328+
}
329+
330+
float *out = bufinfo.buf;
331+
memcpy(out, self->biases, sizeof(float) * self->n_classes);
332+
333+
return mp_const_none;
328334
}
329-
static MP_DEFINE_CONST_FUN_OBJ_1(logreg_model_get_bias_obj, logreg_model_get_bias);
335+
static MP_DEFINE_CONST_FUN_OBJ_2(logreg_model_get_bias_obj, logreg_model_get_bias);
330336

331337
static mp_obj_t logreg_model_set_bias(mp_obj_t self_obj, mp_obj_t bias_obj) {
332338
mp_obj_logreg_model_t *o = MP_OBJ_TO_PTR(self_obj);
333339
logreg_model_t *self = &o->model;
334340

335341
mp_buffer_info_t bufinfo;
336-
mp_get_buffer_raise(bias_obj, &bufinfo, MP_BUFFER_READ);
337-
if (bufinfo.typecode != 'f') {
338-
mp_raise_ValueError(MP_ERROR_TEXT("expecting float32 array"));
339-
}
340-
const size_t n_bias = bufinfo.len / sizeof(float);
341-
if (n_bias != self->n_classes) {
342-
mp_raise_ValueError(MP_ERROR_TEXT("bias array size mismatch"));
342+
if (mp_get_buffer(bias_obj, &bufinfo, MP_BUFFER_READ)) {
343+
if (bufinfo.typecode != 'f') {
344+
mp_raise_ValueError(MP_ERROR_TEXT("expecting float32 array"));
345+
}
346+
const size_t n_bias = bufinfo.len / sizeof(float);
347+
if (n_bias != self->n_classes) {
348+
mp_raise_ValueError(MP_ERROR_TEXT("bias array size mismatch"));
349+
}
350+
const float *biases = bufinfo.buf;
351+
memcpy(self->biases, biases, sizeof(float) * self->n_classes);
352+
} else {
353+
size_t len;
354+
mp_obj_t *items;
355+
mp_obj_get_array(bias_obj, &len, &items);
356+
if (len != self->n_classes) {
357+
mp_raise_ValueError(MP_ERROR_TEXT("bias array size mismatch"));
358+
}
359+
for (size_t i = 0; i < len; i++) {
360+
self->biases[i] = mp_obj_get_float_to_f(items[i]);
361+
}
343362
}
344-
const float *biases = bufinfo.buf;
345-
memcpy(self->biases, biases, sizeof(float) * self->n_classes);
346363

347364
return mp_const_none;
348365
}

tests/test_logreg.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,10 @@ def read_weights(model):
116116

117117
def read_bias(model):
118118
n_classes = model.get_n_classes()
119-
src = model.get_bias()
120119
out = array.array('f', [0.0] * n_classes)
121-
for idx in range(n_classes):
122-
out[idx] = src[idx]
120+
model.get_bias(out)
123121
return out
124122

125-
126123
def alloc_predict_buffers(model):
127124
n_classes = model.get_n_classes()
128125
logits = array.array('f', [0.0] * n_classes)

0 commit comments

Comments
 (0)