Skip to content

Commit a308dd3

Browse files
committed
Added a way to dynamically register a new embedding engine
1 parent aa64257 commit a308dd3

File tree

5 files changed

+349
-30
lines changed

5 files changed

+349
-30
lines changed

src/dbmem-embed.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,8 @@ dbmem_remote_engine_t *dbmem_remote_engine_init (void *ctx, const char *provider
3131
int dbmem_remote_compute_embedding (dbmem_remote_engine_t *engine, const char *text, int text_len, embedding_result_t *result);
3232
void dbmem_remote_engine_free (dbmem_remote_engine_t *engine);
3333

34+
// Custom provider (always available, defined in sqlite-memory.c)
35+
typedef struct dbmem_context dbmem_context;
36+
int dbmem_context_custom_compute (dbmem_context *ctx, const char *text, int text_len, embedding_result_t *result);
37+
3438
#endif

src/dbmem-search.c

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -603,28 +603,36 @@ static int vMemorySearchCursorFilter (sqlite3_vtab_cursor *cur, int idxNum, cons
603603

604604
// compute embedding
605605
embedding_result_t result = {0};
606-
606+
607607
rc = SQLITE_MISUSE;
608+
if (dbmem_context_is_custom(ctx)) {
609+
rc = dbmem_context_custom_compute(ctx, query, (int)strlen(query), &result);
610+
if (rc != 0) {
611+
sqlvTab->zErrMsg = sqlite3_mprintf("%s", dbmem_context_errmsg(ctx));
612+
return SQLITE_ERROR;
613+
}
614+
}
615+
608616
#ifndef DBMEM_OMIT_LOCAL_ENGINE
609-
if (is_local) {
617+
else if (is_local) {
610618
rc = dbmem_local_compute_embedding((dbmem_local_engine_t *)engine, query, (int)strlen(query), &result);
611619
if (rc != 0) {
612620
sqlvTab->zErrMsg = sqlite3_mprintf("%s", dbmem_context_errmsg(ctx));
613621
return SQLITE_ERROR;
614622
}
615623
}
616624
#endif
617-
625+
618626
#ifndef DBMEM_OMIT_REMOTE_ENGINE
619-
if (!is_local) {
627+
else if (!is_local) {
620628
rc = dbmem_remote_compute_embedding((dbmem_remote_engine_t *)engine, query, (int)strlen(query), &result);
621629
if (rc != 0) {
622630
sqlvTab->zErrMsg = sqlite3_mprintf("%s", dbmem_context_errmsg(ctx));
623631
return SQLITE_ERROR;
624632
}
625633
}
626634
#endif
627-
635+
628636
if (rc == SQLITE_MISUSE) {
629637
sqlvTab->zErrMsg = sqlite3_mprintf("%s", "Unable to obtain a valid embedding engine");
630638
return SQLITE_ERROR;

src/sqlite-memory.c

Lines changed: 123 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ struct dbmem_context {
7777
dbmem_local_engine_t *l_engine; // Local embedding engine (llama.cpp based)
7878
dbmem_remote_engine_t *r_engine; // Remote embedding engine (vectors.space based)
7979

80+
// Custom embedding provider
81+
dbmem_provider_t custom_provider; // User-registered callbacks
82+
char *custom_provider_name; // Provider name for matching
83+
void *custom_engine; // Opaque engine from custom_provider.init
84+
bool is_custom; // True when custom provider is active
85+
8086
// Provider configuration
8187
char *provider; // Embedding provider: "local" or remote service name
8288
char *model; // Model path (local) or model identifier (remote)
@@ -564,10 +570,14 @@ static void dbmem_context_free (void *ptr) {
564570
if (ctx->extensions) dbmem_free(ctx->extensions);
565571
if (ctx->cache_buffer) dbmem_free(ctx->cache_buffer);
566572

573+
// custom provider
574+
if (ctx->custom_engine && ctx->custom_provider.free) ctx->custom_provider.free(ctx->custom_engine);
575+
if (ctx->custom_provider_name) dbmem_free(ctx->custom_provider_name);
576+
567577
#ifndef DBMEM_OMIT_LOCAL_ENGINE
568578
if (ctx->l_engine) dbmem_local_engine_free(ctx->l_engine);
569579
#endif
570-
580+
571581
#ifndef DBMEM_OMIT_REMOTE_ENGINE
572582
if (ctx->r_engine) dbmem_remote_engine_free(ctx->r_engine);
573583
#endif
@@ -584,10 +594,29 @@ static void dbmem_context_reset_temp_values (dbmem_context *ctx) {
584594
}
585595

586596
void *dbmem_context_engine (dbmem_context *ctx, bool *is_local) {
597+
if (ctx->is_custom) {
598+
if (is_local) *is_local = false;
599+
return ctx->custom_engine;
600+
}
587601
if (is_local) *is_local = ctx->is_local;
588602
return (ctx->is_local) ? (void *)ctx->l_engine : (void *)ctx->r_engine;
589603
}
590604

605+
bool dbmem_context_is_custom (dbmem_context *ctx) {
606+
return ctx->is_custom;
607+
}
608+
609+
int dbmem_context_custom_compute (dbmem_context *ctx, const char *text, int text_len, embedding_result_t *result) {
610+
dbmem_embedding_result_t cr = {0};
611+
int rc = ctx->custom_provider.compute(ctx->custom_engine, text, text_len, &cr);
612+
if (rc != 0) return rc;
613+
result->n_tokens = cr.n_tokens;
614+
result->n_tokens_truncated = cr.n_tokens_truncated;
615+
result->n_embd = cr.n_embd;
616+
result->embedding = cr.embedding;
617+
return 0;
618+
}
619+
591620
bool dbmem_context_load_vector (dbmem_context *ctx) {
592621
if (ctx->vector_extension_available) return true;
593622

@@ -898,56 +927,80 @@ static void dbmem_set_model (sqlite3_context *context, int argc, sqlite3_value *
898927
}
899928

900929
bool is_local_provider = (strcasecmp(provider, DBMEM_LOCAL_PROVIDER) == 0);
901-
#ifdef DBMEM_OMIT_LOCAL_ENGINE
902-
if (is_local_provider) {
903-
sqlite3_result_error(context, "Local provider cannot be set because SQLite-memory was compiled without local provider support", SQLITE_ERROR);
904-
return;
930+
931+
// check if a custom provider matches
932+
bool is_custom_provider = (ctx->custom_provider_name && ctx->custom_provider.compute &&
933+
strcasecmp(provider, ctx->custom_provider_name) == 0);
934+
935+
if (!is_custom_provider) {
936+
#ifdef DBMEM_OMIT_LOCAL_ENGINE
937+
if (is_local_provider) {
938+
sqlite3_result_error(context, "Local provider cannot be set because SQLite-memory was compiled without local provider support", SQLITE_ERROR);
939+
return;
940+
}
941+
#endif
942+
#ifdef DBMEM_OMIT_REMOTE_ENGINE
943+
if (!is_local_provider) {
944+
sqlite3_result_error(context, "Remote provider cannot be set because SQLite-memory was compiled without remote provider support", SQLITE_ERROR);
945+
return;
946+
}
947+
#endif
905948
}
906-
#endif
907-
#ifdef DBMEM_OMIT_REMOTE_ENGINE
908-
if (!is_local_provider) {
909-
sqlite3_result_error(context, "Remote provider cannot be set because SQLite-memory was compiled without remote provider support", SQLITE_ERROR);
910-
return;
949+
950+
// custom provider path
951+
if (is_custom_provider) {
952+
// free previous custom engine if any
953+
if (ctx->custom_engine && ctx->custom_provider.free) ctx->custom_provider.free(ctx->custom_engine);
954+
ctx->custom_engine = NULL;
955+
956+
ctx->custom_engine = ctx->custom_provider.init(model, ctx->api_key, ctx->error_msg);
957+
if (ctx->custom_engine == NULL) {
958+
sqlite3_result_error(context, ctx->error_msg, -1);
959+
return;
960+
}
961+
ctx->is_custom = true;
962+
ctx->is_local = false;
911963
}
912-
#endif
913-
964+
914965
// if provider is local then make sure model file exists
915966
#ifndef DBMEM_OMIT_LOCAL_ENGINE
916-
if (is_local_provider) {
967+
if (!is_custom_provider && is_local_provider) {
917968
if (dbmem_file_exists(model) == false) {
918969
sqlite3_result_error(context, "Local model not found in the specified path", SQLITE_ERROR);
919970
return;
920971
}
921-
972+
922973
if (ctx->l_engine) dbmem_local_engine_free(ctx->l_engine);
923974
ctx->l_engine = NULL;
924-
975+
925976
ctx->l_engine = dbmem_local_engine_init(ctx, model, ctx->error_msg);
926977
if (ctx->l_engine == NULL) {
927978
sqlite3_result_error(context, ctx->error_msg, -1);
928979
return;
929980
}
930-
981+
931982
if (ctx->engine_warmup) {
932983
dbmem_local_engine_warmup(ctx->l_engine);
933984
}
934-
985+
935986
ctx->is_local = true;
987+
ctx->is_custom = false;
936988
}
937989
#endif
938-
990+
939991
#ifndef DBMEM_OMIT_REMOTE_ENGINE
940-
if (!is_local_provider) {
992+
if (!is_custom_provider && !is_local_provider) {
941993
if (ctx->r_engine) dbmem_remote_engine_free(ctx->r_engine);
942994
ctx->r_engine = NULL;
943-
995+
944996
ctx->r_engine = dbmem_remote_engine_init(ctx, provider, model, ctx->error_msg);
945997
if (ctx->r_engine == NULL) {
946998
sqlite3_result_error(context, ctx->error_msg, -1);
947999
return;
9481000
}
949-
1001+
9501002
ctx->is_local = false;
1003+
ctx->is_custom = false;
9511004
}
9521005
#endif
9531006

@@ -1185,7 +1238,12 @@ static int dbmem_process_callback (const char *text, size_t len, size_t offset,
11851238

11861239
if (!cache_hit) {
11871240
// compute embedding
1188-
if (ctx->is_local) {
1241+
if (ctx->is_custom) {
1242+
rc = dbmem_context_custom_compute(ctx, text, (int)len, &result);
1243+
if (rc != 0) return rc;
1244+
}
1245+
1246+
else if (ctx->is_local) {
11891247
#ifndef DBMEM_OMIT_LOCAL_ENGINE
11901248
rc = dbmem_local_compute_embedding(ctx->l_engine, text, (int)len, &result);
11911249
if (rc != 0) return rc;
@@ -1196,7 +1254,7 @@ static int dbmem_process_callback (const char *text, size_t len, size_t offset,
11961254
#endif
11971255
}
11981256

1199-
if (!ctx->is_local) {
1257+
else {
12001258
#ifndef DBMEM_OMIT_REMOTE_ENGINE
12011259
rc = dbmem_remote_compute_embedding(ctx->r_engine, text, (int)len, &result);
12021260
if (rc != 0) return rc;
@@ -1477,6 +1535,45 @@ static void dbmem_add_directory (sqlite3_context *context, int argc, sqlite3_val
14771535

14781536
// MARK: -
14791537

1538+
#define DBMEM_CTX_POINTER_TYPE "dbmem_context_ptr"
1539+
1540+
// helper to retrieve ctx pointer (registered during init)
1541+
static void dbmem_ctx_ptr (sqlite3_context *context, int argc, sqlite3_value **argv) {
1542+
UNUSED_PARAM(argc);
1543+
UNUSED_PARAM(argv);
1544+
dbmem_context *ctx = (dbmem_context *)sqlite3_user_data(context);
1545+
sqlite3_result_pointer(context, ctx, DBMEM_CTX_POINTER_TYPE, NULL);
1546+
}
1547+
1548+
SQLITE_DBMEMORY_API int sqlite3_memory_register_provider (sqlite3 *db, const char *provider_name, const dbmem_provider_t *provider) {
1549+
if (!db || !provider_name || !provider || !provider->init || !provider->compute) return SQLITE_MISUSE;
1550+
1551+
// retrieve dbmem_context from the helper function registered during init
1552+
sqlite3_stmt *vm = NULL;
1553+
int rc = sqlite3_prepare_v2(db, "SELECT _memory_ctx_ptr()", -1, &vm, NULL);
1554+
if (rc != SQLITE_OK) return rc;
1555+
1556+
if (sqlite3_step(vm) != SQLITE_ROW) {
1557+
sqlite3_finalize(vm);
1558+
return SQLITE_ERROR;
1559+
}
1560+
dbmem_context *ctx = (dbmem_context *)sqlite3_value_pointer(sqlite3_column_value(vm, 0), DBMEM_CTX_POINTER_TYPE);
1561+
sqlite3_finalize(vm);
1562+
if (!ctx) return SQLITE_ERROR;
1563+
1564+
// free previous custom provider if any
1565+
if (ctx->custom_engine && ctx->custom_provider.free) ctx->custom_provider.free(ctx->custom_engine);
1566+
ctx->custom_engine = NULL;
1567+
if (ctx->custom_provider_name) dbmem_free(ctx->custom_provider_name);
1568+
1569+
ctx->custom_provider_name = dbmem_strdup(provider_name);
1570+
if (!ctx->custom_provider_name) return SQLITE_NOMEM;
1571+
1572+
ctx->custom_provider = *provider;
1573+
1574+
return SQLITE_OK;
1575+
}
1576+
14801577
SQLITE_DBMEMORY_API int sqlite3_memory_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) {
14811578
#ifndef SQLITE_CORE
14821579
SQLITE_EXTENSION_INIT2(pApi);
@@ -1499,6 +1596,9 @@ SQLITE_DBMEMORY_API int sqlite3_memory_init (sqlite3 *db, char **pzErrMsg, const
14991596

15001597
rc = sqlite3_create_function_v2(db, "memory_version", 0, SQLITE_UTF8, ctx, dbmem_version, NULL, NULL, dbmem_context_free);
15011598
if (rc != SQLITE_OK) return rc;
1599+
1600+
rc = sqlite3_create_function_v2(db, "_memory_ctx_ptr", 0, SQLITE_UTF8, ctx, dbmem_ctx_ptr, NULL, NULL, NULL);
1601+
if (rc != SQLITE_OK) return rc;
15021602

15031603
rc = sqlite3_create_function_v2(db, "memory_set_option", 2, SQLITE_UTF8, ctx, dbmem_set_option, NULL, NULL, NULL);
15041604
if (rc != SQLITE_OK) return rc;

src/sqlite-memory.h

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,44 @@
2626
extern "C" {
2727
#endif
2828

29-
#define SQLITE_DBMEMORY_VERSION "0.7.5"
29+
#define SQLITE_DBMEMORY_VERSION "0.8.0"
3030

3131
// public API
3232
SQLITE_DBMEMORY_API int sqlite3_memory_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);
3333

34-
// internal APIs
34+
// Custom embedding provider API
35+
// Allows registering a user-defined embedding engine that works regardless of
36+
// DBMEM_OMIT_LOCAL_ENGINE / DBMEM_OMIT_REMOTE_ENGINE compile flags.
37+
// The api_key set via memory_set_apikey() is passed to the init callback.
3538
typedef struct dbmem_context dbmem_context;
3639

40+
typedef struct {
41+
int n_tokens;
42+
int n_tokens_truncated;
43+
int n_embd;
44+
float *embedding; // Engine-owned buffer, valid until next call or free
45+
} dbmem_embedding_result_t;
46+
47+
typedef struct {
48+
// Called when memory_set_model(provider, model) matches this provider.
49+
// api_key is the value set via memory_set_apikey() (may be NULL).
50+
// Return opaque engine pointer, or NULL on error (fill err_msg).
51+
void *(*init)(const char *model, const char *api_key, char err_msg[1024]);
52+
53+
// Compute embedding for text. Return 0 on success, non-zero on error.
54+
int (*compute)(void *engine, const char *text, int text_len, dbmem_embedding_result_t *result);
55+
56+
// Free the engine. Called on context teardown or model change. May be NULL.
57+
void (*free)(void *engine);
58+
} dbmem_provider_t;
59+
60+
// Register a custom embedding provider.
61+
// provider_name: matched against the first argument of memory_set_model().
62+
// Returns SQLITE_OK on success.
63+
SQLITE_DBMEMORY_API int sqlite3_memory_register_provider (sqlite3 *db, const char *provider_name, const dbmem_provider_t *provider);
64+
3765
void *dbmem_context_engine (dbmem_context *ctx, bool *is_local);
66+
bool dbmem_context_is_custom (dbmem_context *ctx);
3867
bool dbmem_context_load_vector (dbmem_context *ctx);
3968
bool dbmem_context_load_sync (dbmem_context *ctx);
4069
bool dbmem_context_perform_fts (dbmem_context *ctx);

0 commit comments

Comments
 (0)