Skip to content

Commit 38e5eac

Browse files
committed
plsr: Support usage as external module
1 parent cd28fd1 commit 38e5eac

8 files changed

Lines changed: 136 additions & 28 deletions

File tree

src/emlearn_plsr/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ endif
2727
MOD = emlearn_plsr
2828

2929
# Source files (.c or .py)
30-
SRC = plsr.c plsr.py
30+
SRC = plsr.c emlearn_plsr.py
3131

3232
# Include to get the rules for compiling and linking the module
3333
include $(MPY_DIR)/py/dynruntime.mk
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
Training helper for EML PLS Regression MicroPython module
33
"""
44

5+
# When used as external C module, the .py is the top-level import,
6+
# and we need to merge the native module symbols at import time
7+
# When used as dynamic native modules (.mpy), .py and native code is merged at build time
8+
try:
9+
from emlearn_plsr_c import *
10+
except ImportError:
11+
pass
12+
513
log_prefix = 'emlearn_plsr:'
614

715
def fit(model, X_train, y_train,
@@ -72,5 +80,3 @@ def fit(model, X_train, y_train,
7280
print(log_prefix, f'Training incomplete after {total_iterations} iterations')
7381

7482
return total_iterations, final_metric
75-
76-

src/emlearn_plsr/micropython.cmake

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
add_library(usermod_emlearn_plsr INTERFACE)
2+
3+
target_sources(usermod_emlearn_plsr INTERFACE
4+
${CMAKE_CURRENT_LIST_DIR}/plsr.c
5+
)
6+
7+
target_include_directories(usermod_emlearn_plsr INTERFACE
8+
${CMAKE_CURRENT_LIST_DIR}
9+
)
10+
11+
target_link_libraries(usermod INTERFACE usermod_emlearn_plsr)

src/emlearn_plsr/micropython.mk

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
MOD_DIR := $(USERMOD_DIR)
2+
3+
# Add all C files to SRC_USERMOD.
4+
SRC_USERMOD_C += $(MOD_DIR)/plsr.c
5+
6+
# We can add our module folder to include paths if needed
7+
CFLAGS_USERMOD += -I$(MOD_DIR) -Wno-unused-function

src/emlearn_plsr/plsr.c

Lines changed: 81 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
1-
// MicroPython native module wrapper for PLS Regression
1+
// MicroPython module wrapper for PLS Regression
2+
// Supports both native module (dynruntime) and external C module builds
3+
4+
#ifdef MICROPY_ENABLE_DYNRUNTIME
25
#include "py/dynruntime.h"
6+
#else
7+
#include "py/runtime.h"
8+
#endif
39

410
#include <string.h>
11+
#include <stdbool.h>
512

613
// NOTE: make sure we do not use sqrtf() wrapper which uses errno, does not work in native module
714
#if USE_IEEE_SQRTF
815
#define sqrtf(x) __ieee754_sqrtf(x)
916
#elif USE_BUILTIN_SQRTF
1017
#define sqrtf(x) __builtin_sqrtf(x)
11-
#else
1218
#endif
1319

1420
#include "eml_plsr.h"
1521

16-
// memset/memcpy for compatibility
22+
#ifdef MICROPY_ENABLE_DYNRUNTIME
23+
// memset/memcpy for compatibility
1724
#if !defined(__linux__)
1825
void *memcpy(void *dst, const void *src, size_t n) {
1926
return mp_fun_table.memmove_(dst, src, n);
@@ -22,27 +29,32 @@ void *memset(void *s, int c, size_t n) {
2229
return mp_fun_table.memset_(s, c, n);
2330
}
2431
#endif
25-
32+
#endif
2633

2734
// MicroPython type for PLSR model
2835
typedef struct _mp_obj_plsr_model_t {
2936
mp_obj_base_t base;
3037
eml_plsr_t model;
3138
uint8_t *memory; // Allocated memory block
39+
size_t memory_size;
3240
uint16_t n_samples;
3341
uint16_t n_features;
3442
uint16_t n_components;
3543
} mp_obj_plsr_model_t;
3644

45+
#if MICROPY_ENABLE_DYNRUNTIME
3746
mp_obj_full_type_t plsr_model_type;
47+
#else
48+
static const mp_obj_type_t plsr_model_type;
49+
#endif
3850

3951
// Create a new instance
4052
static mp_obj_t plsr_model_new(size_t n_args, const mp_obj_t *args) {
4153
// Args: n_samples, n_features, n_components
4254
if (n_args != 3) {
4355
mp_raise_ValueError(MP_ERROR_TEXT("Expected 3 arguments: n_samples, n_features, n_components"));
4456
}
45-
57+
4658
mp_int_t n_samples = mp_obj_get_int(args[0]);
4759
mp_int_t n_features = mp_obj_get_int(args[1]);
4860
mp_int_t n_components = mp_obj_get_int(args[2]);
@@ -56,7 +68,7 @@ static mp_obj_t plsr_model_new(size_t n_args, const mp_obj_t *args) {
5668
}
5769

5870
// Allocate space
59-
mp_obj_plsr_model_t *o = \
71+
mp_obj_plsr_model_t *o =
6072
mp_obj_malloc(mp_obj_plsr_model_t, (mp_obj_type_t *)&plsr_model_type);
6173

6274
o->n_samples = n_samples;
@@ -65,18 +77,19 @@ static mp_obj_t plsr_model_new(size_t n_args, const mp_obj_t *args) {
6577

6678
// Calculate and allocate memory
6779
size_t memory_size = eml_plsr_get_memory_size(n_samples, n_features, n_components);
68-
o->memory = (uint8_t *)m_malloc(memory_size);
69-
80+
o->memory_size = memory_size;
81+
o->memory = m_new(uint8_t, memory_size);
82+
7083
if (!o->memory) {
7184
mp_raise_ValueError(MP_ERROR_TEXT("Failed to allocate PLSR memory"));
7285
}
7386

7487
// Initialize model
75-
EmlError err = eml_plsr_init(&o->model, n_samples, n_features, n_components,
88+
EmlError err = eml_plsr_init(&o->model, n_samples, n_features, n_components,
7689
o->memory, memory_size);
77-
90+
7891
if (err != EmlOk) {
79-
m_free(o->memory);
92+
m_del(uint8_t, o->memory, memory_size);
8093
mp_raise_ValueError(MP_ERROR_TEXT("Failed to initialize PLSR model"));
8194
}
8295

@@ -90,22 +103,21 @@ static mp_obj_t plsr_model_del(mp_obj_t self_obj) {
90103

91104
// Free allocated memory
92105
if (o->memory) {
93-
m_free(o->memory);
106+
m_del(uint8_t, o->memory, o->memory_size);
94107
o->memory = NULL;
95108
}
96109

97110
return mp_const_none;
98111
}
99112
static MP_DEFINE_CONST_FUN_OBJ_1(plsr_model_del_obj, plsr_model_del);
100113

101-
102114
// Start iterative fitting
103115
static mp_obj_t plsr_model_fit_start(size_t n_args, const mp_obj_t *args) {
104116
// Args: self, X, y
105117
if (n_args != 3) {
106118
mp_raise_ValueError(MP_ERROR_TEXT("Expected 3 arguments: self, X, y"));
107119
}
108-
120+
109121
mp_obj_plsr_model_t *o = MP_OBJ_TO_PTR(args[0]);
110122
eml_plsr_t *self = &o->model;
111123

@@ -152,7 +164,7 @@ static mp_obj_t plsr_model_step(size_t n_args, const mp_obj_t *args) {
152164
if (n_args < 1 || n_args > 2) {
153165
mp_raise_ValueError(MP_ERROR_TEXT("Expected 1-2 arguments: self, [tolerance]"));
154166
}
155-
167+
156168
mp_obj_plsr_model_t *o = MP_OBJ_TO_PTR(args[0]);
157169
eml_plsr_t *self = &o->model;
158170

@@ -211,13 +223,14 @@ static mp_obj_t plsr_model_is_complete(mp_obj_t self_obj) {
211223
static MP_DEFINE_CONST_FUN_OBJ_1(plsr_model_is_complete_obj, plsr_model_is_complete);
212224

213225
// Predict using the model
214-
static mp_obj_t plsr_model_predict(mp_obj_fun_bc_t *self_obj,
215-
size_t n_args, size_t n_kw, mp_obj_t *args) {
216-
// Check number of arguments is valid
217-
mp_arg_check_num(n_args, n_kw, 2, 2, false);
226+
static mp_obj_t plsr_model_predict(size_t n_args, const mp_obj_t *args) {
227+
// Args: self, features
228+
if (n_args != 2) {
229+
mp_raise_ValueError(MP_ERROR_TEXT("Expected 2 arguments: self, features"));
230+
}
218231

219232
mp_obj_plsr_model_t *o = MP_OBJ_TO_PTR(args[0]);
220-
eml_plsr_t *self = &o->model;
233+
eml_plsr_t *self = &o->model;
221234

222235
// Extract buffer pointer and verify typecode
223236
mp_buffer_info_t bufinfo;
@@ -242,6 +255,7 @@ static mp_obj_t plsr_model_predict(mp_obj_fun_bc_t *self_obj,
242255

243256
return mp_obj_new_float_from_f(prediction);
244257
}
258+
static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(plsr_model_predict_obj, 2, 2, plsr_model_predict);
245259

246260
// Get convergence metric
247261
static mp_obj_t plsr_model_get_convergence_metric(mp_obj_t self_obj) {
@@ -252,7 +266,13 @@ static mp_obj_t plsr_model_get_convergence_metric(mp_obj_t self_obj) {
252266
}
253267
static MP_DEFINE_CONST_FUN_OBJ_1(plsr_model_get_convergence_metric_obj, plsr_model_get_convergence_metric);
254268

255-
// Module setup
269+
// ============================================================================
270+
// Build-type specific module registration
271+
// ============================================================================
272+
273+
#if MICROPY_ENABLE_DYNRUNTIME
274+
275+
// Forward declaration for locals dict
256276
mp_map_elem_t plsr_model_locals_dict_table[8];
257277
static MP_DEFINE_CONST_DICT(plsr_model_locals_dict, plsr_model_locals_dict_table);
258278

@@ -266,9 +286,9 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
266286
plsr_model_type.base.type = (void*)&mp_fun_table.type_type;
267287
plsr_model_type.flags = MP_TYPE_FLAG_ITER_IS_CUSTOM;
268288
plsr_model_type.name = MP_QSTR_plsr;
269-
289+
270290
// methods
271-
plsr_model_locals_dict_table[0] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict), MP_DYNRUNTIME_MAKE_FUNCTION(plsr_model_predict) };
291+
plsr_model_locals_dict_table[0] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_predict), MP_OBJ_FROM_PTR(&plsr_model_predict_obj) };
272292
plsr_model_locals_dict_table[1] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&plsr_model_del_obj) };
273293
plsr_model_locals_dict_table[2] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_fit_start), MP_OBJ_FROM_PTR(&plsr_model_fit_start_obj) };
274294
plsr_model_locals_dict_table[3] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_step), MP_OBJ_FROM_PTR(&plsr_model_step_obj) };
@@ -277,9 +297,46 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
277297
plsr_model_locals_dict_table[6] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_is_complete), MP_OBJ_FROM_PTR(&plsr_model_is_complete_obj) };
278298
plsr_model_locals_dict_table[7] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_get_convergence_metric), MP_OBJ_FROM_PTR(&plsr_model_get_convergence_metric_obj) };
279299

280-
281300
MP_OBJ_TYPE_SET_SLOT(&plsr_model_type, locals_dict, (void*)&plsr_model_locals_dict, 8);
282301

283302
// This must be last, it restores the globals dict
284303
MP_DYNRUNTIME_INIT_EXIT
285304
}
305+
306+
#else
307+
308+
// External C module build
309+
310+
static const mp_rom_map_elem_t plsr_model_locals_dict_table[] = {
311+
{ MP_ROM_QSTR(MP_QSTR_predict), MP_ROM_PTR(&plsr_model_predict_obj) },
312+
{ MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&plsr_model_del_obj) },
313+
{ MP_ROM_QSTR(MP_QSTR_fit_start), MP_ROM_PTR(&plsr_model_fit_start_obj) },
314+
{ MP_ROM_QSTR(MP_QSTR_step), MP_ROM_PTR(&plsr_model_step_obj) },
315+
{ MP_ROM_QSTR(MP_QSTR_finalize_component), MP_ROM_PTR(&plsr_model_finalize_component_obj) },
316+
{ MP_ROM_QSTR(MP_QSTR_is_converged), MP_ROM_PTR(&plsr_model_is_converged_obj) },
317+
{ MP_ROM_QSTR(MP_QSTR_is_complete), MP_ROM_PTR(&plsr_model_is_complete_obj) },
318+
{ MP_ROM_QSTR(MP_QSTR_get_convergence_metric), MP_ROM_PTR(&plsr_model_get_convergence_metric_obj) },
319+
};
320+
static MP_DEFINE_CONST_DICT(plsr_model_locals_dict, plsr_model_locals_dict_table);
321+
322+
static MP_DEFINE_CONST_OBJ_TYPE(
323+
plsr_model_type,
324+
MP_QSTR_plsr,
325+
MP_TYPE_FLAG_ITER_IS_CUSTOM,
326+
make_new, plsr_model_new,
327+
locals_dict, &plsr_model_locals_dict
328+
);
329+
330+
static const mp_rom_map_elem_t plsr_globals_table[] = {
331+
{ MP_ROM_QSTR(MP_QSTR_new), MP_ROM_PTR(&plsr_model_new_obj) },
332+
{ MP_ROM_QSTR(MP_QSTR_plsr), MP_ROM_PTR(&plsr_model_type) },
333+
};
334+
static MP_DEFINE_CONST_DICT(plsr_globals, plsr_globals_table);
335+
336+
const mp_obj_module_t plsr_cmodule = {
337+
.base = { &mp_type_module },
338+
.globals = (mp_obj_dict_t *)&plsr_globals,
339+
};
340+
MP_REGISTER_MODULE(MP_QSTR_emlearn_plsr_c, plsr_cmodule);
341+
342+
#endif

src/manifest_unix.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@
88
module("emlearn_linreg.py", base_path='./emlearn_linreg')
99
module("emlearn_logreg.py", base_path='./emlearn_logreg')
1010
module("emlearn_extratrees.py", base_path='./emlearn_extratrees')
11+
module("emlearn_plsr.py", base_path='./emlearn_plsr')
1112

1213
#include("$(PORT_DIR)/boards/manifest.py")

src/micropython.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ include(${CMAKE_CURRENT_LIST_DIR}/emlearn_iir/micropython.cmake)
55
include(${CMAKE_CURRENT_LIST_DIR}/emlearn_neighbors/micropython.cmake)
66
include(${CMAKE_CURRENT_LIST_DIR}/emlearn_arrayutils/micropython.cmake)
77
include(${CMAKE_CURRENT_LIST_DIR}/tinymaix_cnn/micropython.cmake)
8+
include(${CMAKE_CURRENT_LIST_DIR}/emlearn_plsr/micropython.cmake)

tests/test_plsr.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,30 @@ def assert_equal(a, b, name="value"):
2222
print(f" ✓ {name}: {a} == {b}")
2323

2424

25+
def test_model_creation():
26+
"""Test model creation and basic properties"""
27+
print("\n=== Test: Model Creation ===")
28+
29+
model = emlearn_plsr.new(8, 3, 2)
30+
assert_true(model is not None, "Model created")
31+
32+
# Check not complete before training
33+
assert_true(not model.is_complete(), "Not complete initially")
34+
35+
# Check invalid dimensions raise error
36+
try:
37+
emlearn_plsr.new(0, 3, 2) # zero samples
38+
assert_true(False, "Should have raised error for zero samples")
39+
except ValueError:
40+
print(" ✓ Caught zero samples")
41+
42+
try:
43+
emlearn_plsr.new(8, 0, 2) # zero features
44+
assert_true(False, "Should have raised error for zero features")
45+
except ValueError:
46+
print(" ✓ Caught zero features")
47+
48+
2549
def test_simple_training():
2650
"""Test basic training and prediction"""
2751
print("\n=== Test: Simple Training ===")
@@ -241,4 +265,5 @@ def run_all_tests():
241265

242266

243267
if __name__ == '__main__':
244-
exit(run_all_tests())
268+
import sys
269+
sys.exit(run_all_tests())

0 commit comments

Comments
 (0)