@@ -313,36 +313,53 @@ static mp_obj_t logreg_model_set_weights(mp_obj_t self_obj, mp_obj_t weights_obj
313313}
314314static MP_DEFINE_CONST_FUN_OBJ_2 (logreg_model_set_weights_obj , logreg_model_set_weights ) ;
315315
316- static mp_obj_t logreg_model_get_bias (mp_obj_t self_obj ) {
316+ static mp_obj_t logreg_model_get_bias (mp_obj_t self_obj , mp_obj_t out_obj ) {
317317 mp_obj_logreg_model_t * o = MP_OBJ_TO_PTR (self_obj );
318318 logreg_model_t * self = & o -> model ;
319319
320- mp_obj_t bias_obj = mp_obj_new_bytearray_by_ref (sizeof (float ) * self -> n_classes , self -> biases );
321- #if MICROPY_ENABLE_DYNRUNTIME
322- mp_obj_array_t * arr = MP_OBJ_TO_PTR (bias_obj );
323- arr -> typecode = 'f' ;
324- #else
325- ((mp_obj_base_t * )MP_OBJ_TO_PTR (bias_obj ))-> type = & mp_type_bytearray ;
326- #endif
327- return bias_obj ;
320+ mp_buffer_info_t bufinfo ;
321+ mp_get_buffer_raise (out_obj , & bufinfo , MP_BUFFER_WRITE );
322+ if (bufinfo .typecode != 'f' ) {
323+ mp_raise_ValueError (MP_ERROR_TEXT ("expecting float32 array" ));
324+ }
325+ const size_t n_out = bufinfo .len / sizeof (float );
326+ if (n_out != self -> n_classes ) {
327+ mp_raise_ValueError (MP_ERROR_TEXT ("bias buffer wrong length" ));
328+ }
329+
330+ float * out = bufinfo .buf ;
331+ memcpy (out , self -> biases , sizeof (float ) * self -> n_classes );
332+
333+ return mp_const_none ;
328334}
329- static MP_DEFINE_CONST_FUN_OBJ_1 (logreg_model_get_bias_obj , logreg_model_get_bias ) ;
335+ static MP_DEFINE_CONST_FUN_OBJ_2 (logreg_model_get_bias_obj , logreg_model_get_bias ) ;
330336
331337static mp_obj_t logreg_model_set_bias (mp_obj_t self_obj , mp_obj_t bias_obj ) {
332338 mp_obj_logreg_model_t * o = MP_OBJ_TO_PTR (self_obj );
333339 logreg_model_t * self = & o -> model ;
334340
335341 mp_buffer_info_t bufinfo ;
336- mp_get_buffer_raise (bias_obj , & bufinfo , MP_BUFFER_READ );
337- if (bufinfo .typecode != 'f' ) {
338- mp_raise_ValueError (MP_ERROR_TEXT ("expecting float32 array" ));
339- }
340- const size_t n_bias = bufinfo .len / sizeof (float );
341- if (n_bias != self -> n_classes ) {
342- mp_raise_ValueError (MP_ERROR_TEXT ("bias array size mismatch" ));
342+ if (mp_get_buffer (bias_obj , & bufinfo , MP_BUFFER_READ )) {
343+ if (bufinfo .typecode != 'f' ) {
344+ mp_raise_ValueError (MP_ERROR_TEXT ("expecting float32 array" ));
345+ }
346+ const size_t n_bias = bufinfo .len / sizeof (float );
347+ if (n_bias != self -> n_classes ) {
348+ mp_raise_ValueError (MP_ERROR_TEXT ("bias array size mismatch" ));
349+ }
350+ const float * biases = bufinfo .buf ;
351+ memcpy (self -> biases , biases , sizeof (float ) * self -> n_classes );
352+ } else {
353+ size_t len ;
354+ mp_obj_t * items ;
355+ mp_obj_get_array (bias_obj , & len , & items );
356+ if (len != self -> n_classes ) {
357+ mp_raise_ValueError (MP_ERROR_TEXT ("bias array size mismatch" ));
358+ }
359+ for (size_t i = 0 ; i < len ; i ++ ) {
360+ self -> biases [i ] = mp_obj_get_float_to_f (items [i ]);
361+ }
343362 }
344- const float * biases = bufinfo .buf ;
345- memcpy (self -> biases , biases , sizeof (float ) * self -> n_classes );
346363
347364 return mp_const_none ;
348365}
0 commit comments