@@ -951,6 +951,7 @@ static void dbmem_set_model (sqlite3_context *context, int argc, sqlite3_value *
951951
952952 // retrieve context
953953 dbmem_context * ctx = (dbmem_context * )sqlite3_user_data (context );
954+ sqlite3 * db = sqlite3_context_db_handle (context );
954955
955956 // detect model change (only if a model was previously configured)
956957 bool model_changed = false;
@@ -979,80 +980,159 @@ static void dbmem_set_model (sqlite3_context *context, int argc, sqlite3_value *
979980 #endif
980981 }
981982
983+ char * new_provider = dbmem_strdup (provider );
984+ char * new_model = dbmem_strdup (model );
985+ if (!new_provider || !new_model ) {
986+ if (new_provider ) dbmemory_free (new_provider );
987+ if (new_model ) dbmemory_free (new_model );
988+ sqlite3_result_error_nomem (context );
989+ return ;
990+ }
991+
992+ char * old_provider = ctx -> provider ;
993+ char * old_model = ctx -> model ;
994+ bool old_is_local = ctx -> is_local ;
995+ bool old_is_custom = ctx -> is_custom ;
996+
997+ #ifndef DBMEM_OMIT_LOCAL_ENGINE
998+ dbmem_local_engine_t * old_l_engine = ctx -> l_engine ;
999+ dbmem_local_engine_t * new_l_engine = ctx -> l_engine ;
1000+ #endif
1001+
1002+ #ifndef DBMEM_OMIT_REMOTE_ENGINE
1003+ dbmem_remote_engine_t * old_r_engine = ctx -> r_engine ;
1004+ dbmem_remote_engine_t * new_r_engine = ctx -> r_engine ;
1005+ #endif
1006+
1007+ void * old_custom_engine = ctx -> custom_engine ;
1008+ void * new_custom_engine = ctx -> custom_engine ;
1009+ bool set_model_started = false;
1010+ int rc = SQLITE_OK ;
1011+
9821012 // custom provider path
9831013 if (is_custom_provider ) {
984- // free previous custom engine if any
985- if (ctx -> custom_engine && ctx -> custom_provider .free ) ctx -> custom_provider .free (ctx -> custom_engine , ctx -> custom_provider .xdata );
986- ctx -> custom_engine = NULL ;
987-
988- ctx -> custom_engine = ctx -> custom_provider .init (model , ctx -> api_key , ctx -> custom_provider .xdata , ctx -> error_msg );
989- if (ctx -> custom_engine == NULL ) {
1014+ new_custom_engine = ctx -> custom_provider .init (model , ctx -> api_key , ctx -> custom_provider .xdata , ctx -> error_msg );
1015+ if (new_custom_engine == NULL ) {
1016+ dbmemory_free (new_provider );
1017+ dbmemory_free (new_model );
9901018 sqlite3_result_error (context , ctx -> error_msg , -1 );
9911019 return ;
9921020 }
993- ctx -> is_custom = true;
994- ctx -> is_local = false;
9951021 }
9961022
9971023 // if provider is local then make sure model file exists
9981024 #ifndef DBMEM_OMIT_LOCAL_ENGINE
9991025 if (!is_custom_provider && is_local_provider ) {
10001026 if (dbmem_file_exists (model ) == false) {
1027+ dbmemory_free (new_provider );
1028+ dbmemory_free (new_model );
10011029 sqlite3_result_error (context , "Local model not found in the specified path" , SQLITE_ERROR );
10021030 return ;
10031031 }
10041032
1005- if (ctx -> l_engine ) dbmem_local_engine_free (ctx -> l_engine );
1006- ctx -> l_engine = NULL ;
1007-
1008- ctx -> l_engine = dbmem_local_engine_init (ctx , model , ctx -> error_msg );
1009- if (ctx -> l_engine == NULL ) {
1033+ new_l_engine = dbmem_local_engine_init (ctx , model , ctx -> error_msg );
1034+ if (new_l_engine == NULL ) {
1035+ dbmemory_free (new_provider );
1036+ dbmemory_free (new_model );
10101037 sqlite3_result_error (context , ctx -> error_msg , -1 );
10111038 return ;
10121039 }
10131040
10141041 if (ctx -> engine_warmup ) {
1015- dbmem_local_engine_warmup (ctx -> l_engine );
1042+ dbmem_local_engine_warmup (new_l_engine );
10161043 }
1017-
1018- ctx -> is_local = true;
1019- ctx -> is_custom = false;
10201044 }
10211045 #endif
10221046
10231047 #ifndef DBMEM_OMIT_REMOTE_ENGINE
10241048 if (!is_custom_provider && !is_local_provider ) {
1025- if (ctx -> r_engine ) dbmem_remote_engine_free (ctx -> r_engine );
1026- ctx -> r_engine = NULL ;
1027-
1028- ctx -> r_engine = dbmem_remote_engine_init (ctx , provider , model , ctx -> error_msg );
1029- if (ctx -> r_engine == NULL ) {
1049+ new_r_engine = dbmem_remote_engine_init (ctx , provider , model , ctx -> error_msg );
1050+ if (new_r_engine == NULL ) {
1051+ dbmemory_free (new_provider );
1052+ dbmemory_free (new_model );
10301053 sqlite3_result_error (context , ctx -> error_msg , -1 );
10311054 return ;
10321055 }
1033-
1034- ctx -> is_local = false;
1035- ctx -> is_custom = false;
10361056 }
10371057 #endif
1038-
1058+
1059+ ctx -> provider = new_provider ;
1060+ ctx -> model = new_model ;
1061+ ctx -> is_local = is_custom_provider ? false : is_local_provider ;
1062+ ctx -> is_custom = is_custom_provider ;
1063+ ctx -> custom_engine = new_custom_engine ;
1064+ #ifndef DBMEM_OMIT_LOCAL_ENGINE
1065+ ctx -> l_engine = new_l_engine ;
1066+ #endif
1067+ #ifndef DBMEM_OMIT_REMOTE_ENGINE
1068+ ctx -> r_engine = new_r_engine ;
1069+ #endif
1070+
1071+ rc = sqlite3_exec (db , "SAVEPOINT dbmem_set_model;" , NULL , NULL , NULL );
1072+ if (rc == SQLITE_OK ) set_model_started = true;
1073+
10391074 // update settings
1040- sqlite3 * db = sqlite3_context_db_handle (context );
1041- int rc = dbmem_settings_write_text (db , DBMEM_SETTINGS_KEY_PROVIDER , provider );
1042- if (rc == SQLITE_OK ) rc = dbmem_settings_write_text (db , DBMEM_SETTINGS_KEY_MODEL , model );
1043-
1044- // sync settings
1045- if (rc == SQLITE_OK ) {
1046- dbmem_settings_sync (ctx , DBMEM_SETTINGS_KEY_PROVIDER , argv [0 ]);
1047- dbmem_settings_sync (ctx , DBMEM_SETTINGS_KEY_MODEL , argv [1 ]);
1048- }
1075+ if (rc == SQLITE_OK ) rc = dbmem_settings_write_text (db , DBMEM_SETTINGS_KEY_PROVIDER , provider );
1076+ if (rc == SQLITE_OK ) rc = dbmem_settings_write_text (db , DBMEM_SETTINGS_KEY_MODEL , model );
10491077
10501078 // reindex all content if the model changed
10511079 if (model_changed && rc == SQLITE_OK ) {
10521080 rc = dbmem_reindex (ctx );
10531081 }
10541082
1055- (rc == SQLITE_OK ) ? sqlite3_result_int (context , 1 ) : sqlite3_result_error (context , sqlite3_errmsg (db ), -1 );
1083+ if (rc == SQLITE_OK && set_model_started ) {
1084+ rc = sqlite3_exec (db , "RELEASE dbmem_set_model;" , NULL , NULL , NULL );
1085+ set_model_started = false;
1086+ }
1087+
1088+ if (rc != SQLITE_OK ) {
1089+ if (set_model_started ) {
1090+ sqlite3_exec (db , "ROLLBACK TO dbmem_set_model; RELEASE dbmem_set_model;" , NULL , NULL , NULL );
1091+ }
1092+
1093+ ctx -> provider = old_provider ;
1094+ ctx -> model = old_model ;
1095+ ctx -> is_local = old_is_local ;
1096+ ctx -> is_custom = old_is_custom ;
1097+ ctx -> custom_engine = old_custom_engine ;
1098+ #ifndef DBMEM_OMIT_LOCAL_ENGINE
1099+ ctx -> l_engine = old_l_engine ;
1100+ if (!is_custom_provider && is_local_provider && new_l_engine != old_l_engine && new_l_engine ) {
1101+ dbmem_local_engine_free (new_l_engine );
1102+ }
1103+ #endif
1104+ #ifndef DBMEM_OMIT_REMOTE_ENGINE
1105+ ctx -> r_engine = old_r_engine ;
1106+ if (!is_custom_provider && !is_local_provider && new_r_engine != old_r_engine && new_r_engine ) {
1107+ dbmem_remote_engine_free (new_r_engine );
1108+ }
1109+ #endif
1110+ if (is_custom_provider && new_custom_engine != old_custom_engine && new_custom_engine && ctx -> custom_provider .free ) {
1111+ ctx -> custom_provider .free (new_custom_engine , ctx -> custom_provider .xdata );
1112+ }
1113+ dbmemory_free (new_provider );
1114+ dbmemory_free (new_model );
1115+ sqlite3_result_error (context , ctx -> error_msg [0 ] ? ctx -> error_msg : sqlite3_errmsg (db ), -1 );
1116+ return ;
1117+ }
1118+
1119+ if (old_provider ) dbmemory_free (old_provider );
1120+ if (old_model ) dbmemory_free (old_model );
1121+ #ifndef DBMEM_OMIT_LOCAL_ENGINE
1122+ if (!is_custom_provider && is_local_provider && old_l_engine && old_l_engine != new_l_engine ) {
1123+ dbmem_local_engine_free (old_l_engine );
1124+ }
1125+ #endif
1126+ #ifndef DBMEM_OMIT_REMOTE_ENGINE
1127+ if (!is_custom_provider && !is_local_provider && old_r_engine && old_r_engine != new_r_engine ) {
1128+ dbmem_remote_engine_free (old_r_engine );
1129+ }
1130+ #endif
1131+ if (is_custom_provider && old_custom_engine && old_custom_engine != new_custom_engine && ctx -> custom_provider .free ) {
1132+ ctx -> custom_provider .free (old_custom_engine , ctx -> custom_provider .xdata );
1133+ }
1134+
1135+ sqlite3_result_int (context , 1 );
10561136}
10571137
10581138static void dbmem_set_apikey (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
0 commit comments