Skip to content

Commit f47d2f9

Browse files
committed
memory_set_model is now an atomic operation
1 parent 17d30e6 commit f47d2f9

2 files changed

Lines changed: 126 additions & 36 deletions

File tree

src/sqlite-memory.c

Lines changed: 116 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,7 @@ static void dbmem_set_model (sqlite3_context *context, int argc, sqlite3_value *
951951

952952
// retrieve context
953953
dbmem_context *ctx = (dbmem_context *)sqlite3_user_data(context);
954+
sqlite3 *db = sqlite3_context_db_handle(context);
954955

955956
// detect model change (only if a model was previously configured)
956957
bool model_changed = false;
@@ -979,80 +980,159 @@ static void dbmem_set_model (sqlite3_context *context, int argc, sqlite3_value *
979980
#endif
980981
}
981982

983+
char *new_provider = dbmem_strdup(provider);
984+
char *new_model = dbmem_strdup(model);
985+
if (!new_provider || !new_model) {
986+
if (new_provider) dbmemory_free(new_provider);
987+
if (new_model) dbmemory_free(new_model);
988+
sqlite3_result_error_nomem(context);
989+
return;
990+
}
991+
992+
char *old_provider = ctx->provider;
993+
char *old_model = ctx->model;
994+
bool old_is_local = ctx->is_local;
995+
bool old_is_custom = ctx->is_custom;
996+
997+
#ifndef DBMEM_OMIT_LOCAL_ENGINE
998+
dbmem_local_engine_t *old_l_engine = ctx->l_engine;
999+
dbmem_local_engine_t *new_l_engine = ctx->l_engine;
1000+
#endif
1001+
1002+
#ifndef DBMEM_OMIT_REMOTE_ENGINE
1003+
dbmem_remote_engine_t *old_r_engine = ctx->r_engine;
1004+
dbmem_remote_engine_t *new_r_engine = ctx->r_engine;
1005+
#endif
1006+
1007+
void *old_custom_engine = ctx->custom_engine;
1008+
void *new_custom_engine = ctx->custom_engine;
1009+
bool set_model_started = false;
1010+
int rc = SQLITE_OK;
1011+
9821012
// custom provider path
9831013
if (is_custom_provider) {
984-
// free previous custom engine if any
985-
if (ctx->custom_engine && ctx->custom_provider.free) ctx->custom_provider.free(ctx->custom_engine, ctx->custom_provider.xdata);
986-
ctx->custom_engine = NULL;
987-
988-
ctx->custom_engine = ctx->custom_provider.init(model, ctx->api_key, ctx->custom_provider.xdata, ctx->error_msg);
989-
if (ctx->custom_engine == NULL) {
1014+
new_custom_engine = ctx->custom_provider.init(model, ctx->api_key, ctx->custom_provider.xdata, ctx->error_msg);
1015+
if (new_custom_engine == NULL) {
1016+
dbmemory_free(new_provider);
1017+
dbmemory_free(new_model);
9901018
sqlite3_result_error(context, ctx->error_msg, -1);
9911019
return;
9921020
}
993-
ctx->is_custom = true;
994-
ctx->is_local = false;
9951021
}
9961022

9971023
// if provider is local then make sure model file exists
9981024
#ifndef DBMEM_OMIT_LOCAL_ENGINE
9991025
if (!is_custom_provider && is_local_provider) {
10001026
if (dbmem_file_exists(model) == false) {
1027+
dbmemory_free(new_provider);
1028+
dbmemory_free(new_model);
10011029
sqlite3_result_error(context, "Local model not found in the specified path", SQLITE_ERROR);
10021030
return;
10031031
}
10041032

1005-
if (ctx->l_engine) dbmem_local_engine_free(ctx->l_engine);
1006-
ctx->l_engine = NULL;
1007-
1008-
ctx->l_engine = dbmem_local_engine_init(ctx, model, ctx->error_msg);
1009-
if (ctx->l_engine == NULL) {
1033+
new_l_engine = dbmem_local_engine_init(ctx, model, ctx->error_msg);
1034+
if (new_l_engine == NULL) {
1035+
dbmemory_free(new_provider);
1036+
dbmemory_free(new_model);
10101037
sqlite3_result_error(context, ctx->error_msg, -1);
10111038
return;
10121039
}
10131040

10141041
if (ctx->engine_warmup) {
1015-
dbmem_local_engine_warmup(ctx->l_engine);
1042+
dbmem_local_engine_warmup(new_l_engine);
10161043
}
1017-
1018-
ctx->is_local = true;
1019-
ctx->is_custom = false;
10201044
}
10211045
#endif
10221046

10231047
#ifndef DBMEM_OMIT_REMOTE_ENGINE
10241048
if (!is_custom_provider && !is_local_provider) {
1025-
if (ctx->r_engine) dbmem_remote_engine_free(ctx->r_engine);
1026-
ctx->r_engine = NULL;
1027-
1028-
ctx->r_engine = dbmem_remote_engine_init(ctx, provider, model, ctx->error_msg);
1029-
if (ctx->r_engine == NULL) {
1049+
new_r_engine = dbmem_remote_engine_init(ctx, provider, model, ctx->error_msg);
1050+
if (new_r_engine == NULL) {
1051+
dbmemory_free(new_provider);
1052+
dbmemory_free(new_model);
10301053
sqlite3_result_error(context, ctx->error_msg, -1);
10311054
return;
10321055
}
1033-
1034-
ctx->is_local = false;
1035-
ctx->is_custom = false;
10361056
}
10371057
#endif
1038-
1058+
1059+
ctx->provider = new_provider;
1060+
ctx->model = new_model;
1061+
ctx->is_local = is_custom_provider ? false : is_local_provider;
1062+
ctx->is_custom = is_custom_provider;
1063+
ctx->custom_engine = new_custom_engine;
1064+
#ifndef DBMEM_OMIT_LOCAL_ENGINE
1065+
ctx->l_engine = new_l_engine;
1066+
#endif
1067+
#ifndef DBMEM_OMIT_REMOTE_ENGINE
1068+
ctx->r_engine = new_r_engine;
1069+
#endif
1070+
1071+
rc = sqlite3_exec(db, "SAVEPOINT dbmem_set_model;", NULL, NULL, NULL);
1072+
if (rc == SQLITE_OK) set_model_started = true;
1073+
10391074
// update settings
1040-
sqlite3 *db = sqlite3_context_db_handle(context);
1041-
int rc = dbmem_settings_write_text (db, DBMEM_SETTINGS_KEY_PROVIDER, provider);
1042-
if (rc == SQLITE_OK) rc = dbmem_settings_write_text (db, DBMEM_SETTINGS_KEY_MODEL, model);
1043-
1044-
// sync settings
1045-
if (rc == SQLITE_OK) {
1046-
dbmem_settings_sync(ctx, DBMEM_SETTINGS_KEY_PROVIDER, argv[0]);
1047-
dbmem_settings_sync(ctx, DBMEM_SETTINGS_KEY_MODEL, argv[1]);
1048-
}
1075+
if (rc == SQLITE_OK) rc = dbmem_settings_write_text(db, DBMEM_SETTINGS_KEY_PROVIDER, provider);
1076+
if (rc == SQLITE_OK) rc = dbmem_settings_write_text(db, DBMEM_SETTINGS_KEY_MODEL, model);
10491077

10501078
// reindex all content if the model changed
10511079
if (model_changed && rc == SQLITE_OK) {
10521080
rc = dbmem_reindex(ctx);
10531081
}
10541082

1055-
(rc == SQLITE_OK) ? sqlite3_result_int(context, 1) : sqlite3_result_error(context, sqlite3_errmsg(db), -1);
1083+
if (rc == SQLITE_OK && set_model_started) {
1084+
rc = sqlite3_exec(db, "RELEASE dbmem_set_model;", NULL, NULL, NULL);
1085+
set_model_started = false;
1086+
}
1087+
1088+
if (rc != SQLITE_OK) {
1089+
if (set_model_started) {
1090+
sqlite3_exec(db, "ROLLBACK TO dbmem_set_model; RELEASE dbmem_set_model;", NULL, NULL, NULL);
1091+
}
1092+
1093+
ctx->provider = old_provider;
1094+
ctx->model = old_model;
1095+
ctx->is_local = old_is_local;
1096+
ctx->is_custom = old_is_custom;
1097+
ctx->custom_engine = old_custom_engine;
1098+
#ifndef DBMEM_OMIT_LOCAL_ENGINE
1099+
ctx->l_engine = old_l_engine;
1100+
if (!is_custom_provider && is_local_provider && new_l_engine != old_l_engine && new_l_engine) {
1101+
dbmem_local_engine_free(new_l_engine);
1102+
}
1103+
#endif
1104+
#ifndef DBMEM_OMIT_REMOTE_ENGINE
1105+
ctx->r_engine = old_r_engine;
1106+
if (!is_custom_provider && !is_local_provider && new_r_engine != old_r_engine && new_r_engine) {
1107+
dbmem_remote_engine_free(new_r_engine);
1108+
}
1109+
#endif
1110+
if (is_custom_provider && new_custom_engine != old_custom_engine && new_custom_engine && ctx->custom_provider.free) {
1111+
ctx->custom_provider.free(new_custom_engine, ctx->custom_provider.xdata);
1112+
}
1113+
dbmemory_free(new_provider);
1114+
dbmemory_free(new_model);
1115+
sqlite3_result_error(context, ctx->error_msg[0] ? ctx->error_msg : sqlite3_errmsg(db), -1);
1116+
return;
1117+
}
1118+
1119+
if (old_provider) dbmemory_free(old_provider);
1120+
if (old_model) dbmemory_free(old_model);
1121+
#ifndef DBMEM_OMIT_LOCAL_ENGINE
1122+
if (!is_custom_provider && is_local_provider && old_l_engine && old_l_engine != new_l_engine) {
1123+
dbmem_local_engine_free(old_l_engine);
1124+
}
1125+
#endif
1126+
#ifndef DBMEM_OMIT_REMOTE_ENGINE
1127+
if (!is_custom_provider && !is_local_provider && old_r_engine && old_r_engine != new_r_engine) {
1128+
dbmem_remote_engine_free(old_r_engine);
1129+
}
1130+
#endif
1131+
if (is_custom_provider && old_custom_engine && old_custom_engine != new_custom_engine && ctx->custom_provider.free) {
1132+
ctx->custom_provider.free(old_custom_engine, ctx->custom_provider.xdata);
1133+
}
1134+
1135+
sqlite3_result_int(context, 1);
10561136
}
10571137

10581138
static void dbmem_set_apikey (sqlite3_context *context, int argc, sqlite3_value **argv) {

test/unittest.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2566,6 +2566,16 @@ TEST(sqlite_set_model_failed_reindex_preserves_existing_rows) {
25662566
ASSERT_EQ(rc, SQLITE_OK);
25672567
ASSERT_STR_EQ(context, "keep");
25682568

2569+
char provider[64];
2570+
rc = exec_get_text(db, "SELECT memory_get_option('provider');", provider, sizeof(provider));
2571+
ASSERT_EQ(rc, SQLITE_OK);
2572+
ASSERT_STR_EQ(provider, "dummy");
2573+
2574+
char model[64];
2575+
rc = exec_get_text(db, "SELECT memory_get_option('model');", model, sizeof(model));
2576+
ASSERT_EQ(rc, SQLITE_OK);
2577+
ASSERT_STR_EQ(model, "test-model");
2578+
25692579
sqlite3_close(db);
25702580
}
25712581

0 commit comments

Comments
 (0)