Skip to content

Commit 04e8ff1

Browse files
committed
cnn: Try to support both int8 and fp32 for external
A bit tricky since one needs to avoid duplicate symbols, and the MicroPython build system is quite opinionated
1 parent 26945fd commit 04e8ff1

10 files changed

Lines changed: 85 additions & 29 deletions

File tree

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
/*
2+
* emlearn_cnn_fp32 wrapper
3+
* This file sets CONFIG_FP32 and includes the main mod_cnn.c
4+
*/
5+
#define CONFIG_FP32
6+
#include "../tinymaix_cnn/mod_cnn.c"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# emlearn_cnn_fp32 - wrapper for frozen unix build
22
# Imports the native C module compiled with fp32 configuration
33
try:
4-
from tinymaix_cnn_fp32_native import *
4+
from emlearn_cnn_fp32_native import *
55
except ImportError:
66
pass
Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
# emlearn_cnn_fp32 wrapper for Unix port
2-
# Compiles mod_cnn.c with fp32 configuration
2+
# This wrapper sets CONFIG_FP32 before including mod_cnn.c
33

44
CNN_SRC := $(USERMOD_DIR)/../tinymaix_cnn
55

6-
# Add C source file from tinymaix_cnn
7-
SRC_USERMOD_C += $(CNN_SRC)/mod_cnn.c
6+
# Add wrapper C file which defines CONFIG_FP32
7+
SRC_USERMOD_C += $(USERMOD_DIR)/emlearn_cnn_fp32.c
88

9-
# Include paths
9+
# Include paths - need to include tm_port.h location
1010
CFLAGS_USERMOD += -I$(CNN_SRC)/fp32
1111
CFLAGS_USERMOD += -I$(CNN_SRC)/../../dependencies/TinyMaix/include
1212
CFLAGS_USERMOD += -I$(CNN_SRC)/../../dependencies/TinyMaix/src
1313

14-
# Compile flags
14+
# Compile flags to suppress TinyMaix warnings
1515
CFLAGS_USERMOD += -Wno-error=unused-variable -Wno-error=multichar -Wdouble-promotion
16-
17-
# Define CONFIG_FP32 for the C code
18-
CFLAGS_USERMOD += -DCONFIG_FP32
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
/*
2+
* emlearn_cnn_int8 wrapper
3+
* This file sets CONFIG_INT8 and includes the main mod_cnn.c
4+
*/
5+
#define CONFIG_INT8
6+
#include "../tinymaix_cnn/mod_cnn.c"
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# emlearn_cnn_int8 - wrapper for frozen unix build
2+
# Imports the native C module compiled with int8 configuration
3+
try:
4+
from emlearn_cnn_int8_native import *
5+
except ImportError:
6+
pass
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# emlearn_cnn_int8 wrapper for Unix port
2+
# This wrapper sets CONFIG_INT8 before including mod_cnn.c
3+
4+
CNN_SRC := $(USERMOD_DIR)/../tinymaix_cnn
5+
6+
# Add wrapper C file which defines CONFIG_INT8
7+
SRC_USERMOD_C += $(USERMOD_DIR)/emlearn_cnn_int8.c
8+
9+
# Include paths - need to include tm_port.h location
10+
CFLAGS_USERMOD += -I$(CNN_SRC)/int8
11+
CFLAGS_USERMOD += -I$(CNN_SRC)/../../dependencies/TinyMaix/include
12+
CFLAGS_USERMOD += -I$(CNN_SRC)/../../dependencies/TinyMaix/src
13+
14+
# Compile flags to suppress TinyMaix warnings
15+
CFLAGS_USERMOD += -Wno-error=unused-variable -Wno-error=multichar -Wdouble-promotion

src/manifest_unix.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Ref https://docs.micropython.org/en/latest/reference/manifest.html
55
module("emlearn_trees.py", base_path='./emlearn_trees')
66
module("emlearn_kmeans.py", base_path='./emlearn_kmeans')
7+
module("emlearn_cnn_int8.py", base_path='./emlearn_cnn_int8')
78
module("emlearn_cnn_fp32.py", base_path='./emlearn_cnn_fp32')
89
module("emlearn_fft.py", base_path='./emlearn_fft')
910
module("emlearn_linreg.py", base_path='./emlearn_linreg')

src/tinymaix_cnn/fp32/tm_port.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,16 @@ limitations under the License.
4444
#define TM_WEAK __attribute__((weak))
4545

4646
// Disable "static" (non-const) globals, since they are not supported by MicroPython mpy_ld.py
47-
#define TM_STATIC
47+
// But when building multiple variants, we need static to avoid duplicate definitions
48+
#ifdef CONFIG_FP32
49+
#define TM_STATIC static
50+
#else
51+
#define TM_STATIC
52+
#endif
4853

4954
// Use MicroPython for dynamic allocation
5055
#define tm_malloc(x) m_malloc(x)
51-
#define tm_free(x) mod_cnn_free(x)
56+
#define tm_free(x) CNN_FREE(x)
5257

5358
// FIXME: set theese to use MicroPython primitives
5459

src/tinymaix_cnn/int8/tm_port.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,16 @@ limitations under the License.
4444
#define TM_WEAK __attribute__((weak))
4545

4646
// Disable "static" (non-const) globals, since they are not supported by MicroPython mpy_ld.py
47-
#define TM_STATIC
47+
// But when building multiple variants, we need static to avoid duplicate definitions
48+
#ifdef CONFIG_INT8
49+
#define TM_STATIC static
50+
#else
51+
#define TM_STATIC
52+
#endif
4853

4954
// Use MicroPython for dynamic allocation
5055
#define tm_malloc(x) m_malloc(x)
51-
#define tm_free(x) mod_cnn_free(x)
56+
#define tm_free(x) CNN_FREE(x)
5257

5358
// FIXME: set theese to use MicroPython primitives
5459

src/tinymaix_cnn/mod_cnn.c

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,23 @@
55
#include "py/runtime.h"
66
#endif
77

8+
// Define unique symbol names based on CONFIG
9+
#ifdef CONFIG_FP32
10+
#define CNN_TYPE mod_cnn_fp32_type
11+
#define CNN_CMODULE mod_cnn_fp32_cmodule
12+
#define CNN_FREE mod_cnn_fp32_free
13+
#elif defined(CONFIG_INT8)
14+
#define CNN_TYPE mod_cnn_int8_type
15+
#define CNN_CMODULE mod_cnn_int8_cmodule
16+
#define CNN_FREE mod_cnn_int8_free
17+
#else
18+
#define CNN_TYPE mod_cnn_int8_type
19+
#define CNN_CMODULE mod_cnn_int8_cmodule
20+
#define CNN_FREE mod_cnn_int8_free
21+
#endif
822

9-
void mod_cnn_free(void *ptr);
23+
// Forward declaration for tm_port.h
24+
void CNN_FREE(void *ptr);
1025

1126
// TinyMaix config
1227
#include "./tm_port.h"
@@ -81,18 +96,18 @@ typedef struct _mp_obj_mod_cnn_t {
8196
} mp_obj_mod_cnn_t;
8297

8398
#if MICROPY_ENABLE_DYNRUNTIME
84-
mp_obj_full_type_t mod_cnn_type;
99+
mp_obj_full_type_t CNN_TYPE;
85100
#else
86-
static const mp_obj_type_t mod_cnn_type;
101+
static const mp_obj_type_t CNN_TYPE;
87102
#endif
88103

89104

90-
void mod_cnn_free(void *ptr)
105+
void CNN_FREE(void *ptr)
91106
{
92107
#if MICROPY_ENABLE_DYNRUNTIME
93108
return m_free(ptr);
94109
#else
95-
return m_del(void *, ptr, 0); // XXX: not sure if safe
110+
return m_del(void *, ptr, 0);
96111
#endif
97112
}
98113

@@ -116,7 +131,7 @@ static mp_obj_t mod_cnn_new(mp_obj_t model_data_obj) {
116131
const int model_data_length = bufinfo.len / sizeof(*model_data_buffer);
117132

118133
// Construct object
119-
mp_obj_mod_cnn_t *o = mp_obj_malloc(mp_obj_mod_cnn_t, (mp_obj_type_t *)&mod_cnn_type);
134+
mp_obj_mod_cnn_t *o = mp_obj_malloc(mp_obj_mod_cnn_t, (mp_obj_type_t *)&CNN_TYPE);
120135
tm_mdl_t *model = &o->model;
121136

122137
// Copy the model data
@@ -283,15 +298,15 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
283298

284299
mp_store_global(MP_QSTR_new, MP_OBJ_FROM_PTR(&mod_cnn_new_obj));
285300

286-
mod_cnn_type.base.type = (void*)&mp_fun_table.type_type;
287-
mod_cnn_type.flags = MP_TYPE_FLAG_ITER_IS_CUSTOM;
288-
mod_cnn_type.name = MP_QSTR_tinymaixcnn;
301+
CNN_TYPE.base.type = (void*)&mp_fun_table.type_type;
302+
CNN_TYPE.flags = MP_TYPE_FLAG_ITER_IS_CUSTOM;
303+
CNN_TYPE.name = MP_QSTR_tinymaixcnn;
289304
// methods
290305
mod_locals_dict_table[0] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_run), MP_OBJ_FROM_PTR(&mod_cnn_run_obj) };
291306
mod_locals_dict_table[1] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&mod_cnn_del_obj) };
292307
mod_locals_dict_table[2] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_output_dimensions), MP_OBJ_FROM_PTR(&mod_cnn_output_dimensions_obj) };
293308

294-
MP_OBJ_TYPE_SET_SLOT(&mod_cnn_type, locals_dict, (void*)&mod_locals_dict, 2);
309+
MP_OBJ_TYPE_SET_SLOT(&CNN_TYPE, locals_dict, (void*)&mod_locals_dict, 2);
295310

296311
// This must be last, it restores the globals dict
297312
MP_DYNRUNTIME_INIT_EXIT
@@ -308,7 +323,7 @@ static MP_DEFINE_CONST_DICT(mod_cnn_locals_dict, mod_cnn_locals_dict_table);
308323

309324

310325
static MP_DEFINE_CONST_OBJ_TYPE(
311-
mod_cnn_type,
326+
CNN_TYPE,
312327
MP_QSTR_tinymaix_cnn,
313328
MP_TYPE_FLAG_NONE,
314329
locals_dict, &mod_cnn_locals_dict
@@ -320,17 +335,17 @@ static const mp_rom_map_elem_t mod_cnn_globals_table[] = {
320335
};
321336
static MP_DEFINE_CONST_DICT(mod_cnn_globals, mod_cnn_globals_table);
322337

323-
const mp_obj_module_t mod_cnn_cmodule = {
338+
const mp_obj_module_t CNN_CMODULE = {
324339
.base = { &mp_type_module },
325340
.globals = (mp_obj_dict_t *)&mod_cnn_globals,
326341
};
327342

328343
// Module name depends on CONFIG
329344
#ifdef CONFIG_FP32
330-
MP_REGISTER_MODULE(MP_QSTR_tinymaix_cnn_fp32_native, mod_cnn_cmodule);
345+
MP_REGISTER_MODULE(MP_QSTR_emlearn_cnn_fp32_native, CNN_CMODULE);
346+
#elif defined(CONFIG_INT8)
347+
MP_REGISTER_MODULE(MP_QSTR_emlearn_cnn_int8_native, CNN_CMODULE);
331348
#else
332-
MP_REGISTER_MODULE(MP_QSTR_emlearn_cnn_int8, mod_cnn_cmodule);
349+
MP_REGISTER_MODULE(MP_QSTR_emlearn_cnn_int8_native, CNN_CMODULE);
333350
#endif
334351
#endif
335-
336-

0 commit comments

Comments
 (0)