@@ -77,6 +77,12 @@ struct dbmem_context {
7777 dbmem_local_engine_t * l_engine ; // Local embedding engine (llama.cpp based)
7878 dbmem_remote_engine_t * r_engine ; // Remote embedding engine (vectors.space based)
7979
80+ // Custom embedding provider
81+ dbmem_provider_t custom_provider ; // User-registered callbacks
82+ char * custom_provider_name ; // Provider name for matching
83+ void * custom_engine ; // Opaque engine from custom_provider.init
84+ bool is_custom ; // True when custom provider is active
85+
8086 // Provider configuration
8187 char * provider ; // Embedding provider: "local" or remote service name
8288 char * model ; // Model path (local) or model identifier (remote)
@@ -564,10 +570,14 @@ static void dbmem_context_free (void *ptr) {
564570 if (ctx -> extensions ) dbmem_free (ctx -> extensions );
565571 if (ctx -> cache_buffer ) dbmem_free (ctx -> cache_buffer );
566572
573+ // custom provider
574+ if (ctx -> custom_engine && ctx -> custom_provider .free ) ctx -> custom_provider .free (ctx -> custom_engine );
575+ if (ctx -> custom_provider_name ) dbmem_free (ctx -> custom_provider_name );
576+
567577 #ifndef DBMEM_OMIT_LOCAL_ENGINE
568578 if (ctx -> l_engine ) dbmem_local_engine_free (ctx -> l_engine );
569579 #endif
570-
580+
571581 #ifndef DBMEM_OMIT_REMOTE_ENGINE
572582 if (ctx -> r_engine ) dbmem_remote_engine_free (ctx -> r_engine );
573583 #endif
@@ -584,10 +594,29 @@ static void dbmem_context_reset_temp_values (dbmem_context *ctx) {
584594}
585595
586596void * dbmem_context_engine (dbmem_context * ctx , bool * is_local ) {
597+ if (ctx -> is_custom ) {
598+ if (is_local ) * is_local = false;
599+ return ctx -> custom_engine ;
600+ }
587601 if (is_local ) * is_local = ctx -> is_local ;
588602 return (ctx -> is_local ) ? (void * )ctx -> l_engine : (void * )ctx -> r_engine ;
589603}
590604
605+ bool dbmem_context_is_custom (dbmem_context * ctx ) {
606+ return ctx -> is_custom ;
607+ }
608+
609+ int dbmem_context_custom_compute (dbmem_context * ctx , const char * text , int text_len , embedding_result_t * result ) {
610+ dbmem_embedding_result_t cr = {0 };
611+ int rc = ctx -> custom_provider .compute (ctx -> custom_engine , text , text_len , & cr );
612+ if (rc != 0 ) return rc ;
613+ result -> n_tokens = cr .n_tokens ;
614+ result -> n_tokens_truncated = cr .n_tokens_truncated ;
615+ result -> n_embd = cr .n_embd ;
616+ result -> embedding = cr .embedding ;
617+ return 0 ;
618+ }
619+
591620bool dbmem_context_load_vector (dbmem_context * ctx ) {
592621 if (ctx -> vector_extension_available ) return true;
593622
@@ -898,56 +927,80 @@ static void dbmem_set_model (sqlite3_context *context, int argc, sqlite3_value *
898927 }
899928
900929 bool is_local_provider = (strcasecmp (provider , DBMEM_LOCAL_PROVIDER ) == 0 );
901- #ifdef DBMEM_OMIT_LOCAL_ENGINE
902- if (is_local_provider ) {
903- sqlite3_result_error (context , "Local provider cannot be set because SQLite-memory was compiled without local provider support" , SQLITE_ERROR );
904- return ;
930+
931+ // check if a custom provider matches
932+ bool is_custom_provider = (ctx -> custom_provider_name && ctx -> custom_provider .compute &&
933+ strcasecmp (provider , ctx -> custom_provider_name ) == 0 );
934+
935+ if (!is_custom_provider ) {
936+ #ifdef DBMEM_OMIT_LOCAL_ENGINE
937+ if (is_local_provider ) {
938+ sqlite3_result_error (context , "Local provider cannot be set because SQLite-memory was compiled without local provider support" , SQLITE_ERROR );
939+ return ;
940+ }
941+ #endif
942+ #ifdef DBMEM_OMIT_REMOTE_ENGINE
943+ if (!is_local_provider ) {
944+ sqlite3_result_error (context , "Remote provider cannot be set because SQLite-memory was compiled without remote provider support" , SQLITE_ERROR );
945+ return ;
946+ }
947+ #endif
905948 }
906- #endif
907- #ifdef DBMEM_OMIT_REMOTE_ENGINE
908- if (!is_local_provider ) {
909- sqlite3_result_error (context , "Remote provider cannot be set because SQLite-memory was compiled without remote provider support" , SQLITE_ERROR );
910- return ;
949+
950+ // custom provider path
951+ if (is_custom_provider ) {
952+ // free previous custom engine if any
953+ if (ctx -> custom_engine && ctx -> custom_provider .free ) ctx -> custom_provider .free (ctx -> custom_engine );
954+ ctx -> custom_engine = NULL ;
955+
956+ ctx -> custom_engine = ctx -> custom_provider .init (model , ctx -> api_key , ctx -> error_msg );
957+ if (ctx -> custom_engine == NULL ) {
958+ sqlite3_result_error (context , ctx -> error_msg , -1 );
959+ return ;
960+ }
961+ ctx -> is_custom = true;
962+ ctx -> is_local = false;
911963 }
912- #endif
913-
964+
914965 // if provider is local then make sure model file exists
915966 #ifndef DBMEM_OMIT_LOCAL_ENGINE
916- if (is_local_provider ) {
967+ if (! is_custom_provider && is_local_provider ) {
917968 if (dbmem_file_exists (model ) == false) {
918969 sqlite3_result_error (context , "Local model not found in the specified path" , SQLITE_ERROR );
919970 return ;
920971 }
921-
972+
922973 if (ctx -> l_engine ) dbmem_local_engine_free (ctx -> l_engine );
923974 ctx -> l_engine = NULL ;
924-
975+
925976 ctx -> l_engine = dbmem_local_engine_init (ctx , model , ctx -> error_msg );
926977 if (ctx -> l_engine == NULL ) {
927978 sqlite3_result_error (context , ctx -> error_msg , -1 );
928979 return ;
929980 }
930-
981+
931982 if (ctx -> engine_warmup ) {
932983 dbmem_local_engine_warmup (ctx -> l_engine );
933984 }
934-
985+
935986 ctx -> is_local = true;
987+ ctx -> is_custom = false;
936988 }
937989 #endif
938-
990+
939991 #ifndef DBMEM_OMIT_REMOTE_ENGINE
940- if (!is_local_provider ) {
992+ if (!is_custom_provider && ! is_local_provider ) {
941993 if (ctx -> r_engine ) dbmem_remote_engine_free (ctx -> r_engine );
942994 ctx -> r_engine = NULL ;
943-
995+
944996 ctx -> r_engine = dbmem_remote_engine_init (ctx , provider , model , ctx -> error_msg );
945997 if (ctx -> r_engine == NULL ) {
946998 sqlite3_result_error (context , ctx -> error_msg , -1 );
947999 return ;
9481000 }
949-
1001+
9501002 ctx -> is_local = false;
1003+ ctx -> is_custom = false;
9511004 }
9521005 #endif
9531006
@@ -1185,7 +1238,12 @@ static int dbmem_process_callback (const char *text, size_t len, size_t offset,
11851238
11861239 if (!cache_hit ) {
11871240 // compute embedding
1188- if (ctx -> is_local ) {
1241+ if (ctx -> is_custom ) {
1242+ rc = dbmem_context_custom_compute (ctx , text , (int )len , & result );
1243+ if (rc != 0 ) return rc ;
1244+ }
1245+
1246+ else if (ctx -> is_local ) {
11891247 #ifndef DBMEM_OMIT_LOCAL_ENGINE
11901248 rc = dbmem_local_compute_embedding (ctx -> l_engine , text , (int )len , & result );
11911249 if (rc != 0 ) return rc ;
@@ -1196,7 +1254,7 @@ static int dbmem_process_callback (const char *text, size_t len, size_t offset,
11961254 #endif
11971255 }
11981256
1199- if (! ctx -> is_local ) {
1257+ else {
12001258 #ifndef DBMEM_OMIT_REMOTE_ENGINE
12011259 rc = dbmem_remote_compute_embedding (ctx -> r_engine , text , (int )len , & result );
12021260 if (rc != 0 ) return rc ;
@@ -1477,6 +1535,45 @@ static void dbmem_add_directory (sqlite3_context *context, int argc, sqlite3_val
14771535
14781536// MARK: -
14791537
1538+ #define DBMEM_CTX_POINTER_TYPE "dbmem_context_ptr"
1539+
1540+ // helper to retrieve ctx pointer (registered during init)
1541+ static void dbmem_ctx_ptr (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
1542+ UNUSED_PARAM (argc );
1543+ UNUSED_PARAM (argv );
1544+ dbmem_context * ctx = (dbmem_context * )sqlite3_user_data (context );
1545+ sqlite3_result_pointer (context , ctx , DBMEM_CTX_POINTER_TYPE , NULL );
1546+ }
1547+
1548+ SQLITE_DBMEMORY_API int sqlite3_memory_register_provider (sqlite3 * db , const char * provider_name , const dbmem_provider_t * provider ) {
1549+ if (!db || !provider_name || !provider || !provider -> init || !provider -> compute ) return SQLITE_MISUSE ;
1550+
1551+ // retrieve dbmem_context from the helper function registered during init
1552+ sqlite3_stmt * vm = NULL ;
1553+ int rc = sqlite3_prepare_v2 (db , "SELECT _memory_ctx_ptr()" , -1 , & vm , NULL );
1554+ if (rc != SQLITE_OK ) return rc ;
1555+
1556+ if (sqlite3_step (vm ) != SQLITE_ROW ) {
1557+ sqlite3_finalize (vm );
1558+ return SQLITE_ERROR ;
1559+ }
1560+ dbmem_context * ctx = (dbmem_context * )sqlite3_value_pointer (sqlite3_column_value (vm , 0 ), DBMEM_CTX_POINTER_TYPE );
1561+ sqlite3_finalize (vm );
1562+ if (!ctx ) return SQLITE_ERROR ;
1563+
1564+ // free previous custom provider if any
1565+ if (ctx -> custom_engine && ctx -> custom_provider .free ) ctx -> custom_provider .free (ctx -> custom_engine );
1566+ ctx -> custom_engine = NULL ;
1567+ if (ctx -> custom_provider_name ) dbmem_free (ctx -> custom_provider_name );
1568+
1569+ ctx -> custom_provider_name = dbmem_strdup (provider_name );
1570+ if (!ctx -> custom_provider_name ) return SQLITE_NOMEM ;
1571+
1572+ ctx -> custom_provider = * provider ;
1573+
1574+ return SQLITE_OK ;
1575+ }
1576+
14801577SQLITE_DBMEMORY_API int sqlite3_memory_init (sqlite3 * db , char * * pzErrMsg , const sqlite3_api_routines * pApi ) {
14811578 #ifndef SQLITE_CORE
14821579 SQLITE_EXTENSION_INIT2 (pApi );
@@ -1499,6 +1596,9 @@ SQLITE_DBMEMORY_API int sqlite3_memory_init (sqlite3 *db, char **pzErrMsg, const
14991596
15001597 rc = sqlite3_create_function_v2 (db , "memory_version" , 0 , SQLITE_UTF8 , ctx , dbmem_version , NULL , NULL , dbmem_context_free );
15011598 if (rc != SQLITE_OK ) return rc ;
1599+
1600+ rc = sqlite3_create_function_v2 (db , "_memory_ctx_ptr" , 0 , SQLITE_UTF8 , ctx , dbmem_ctx_ptr , NULL , NULL , NULL );
1601+ if (rc != SQLITE_OK ) return rc ;
15021602
15031603 rc = sqlite3_create_function_v2 (db , "memory_set_option" , 2 , SQLITE_UTF8 , ctx , dbmem_set_option , NULL , NULL , NULL );
15041604 if (rc != SQLITE_OK ) return rc ;
0 commit comments