Skip to content

Commit 86e33fc

Browse files
committed
Merge branch 'pr-24'
# Conflicts: # src/sqlite-ai.h # tests/c/unittest.c
2 parents 2b30aca + fcf8d10 commit 86e33fc

File tree

2 files changed

+144
-5
lines changed

2 files changed

+144
-5
lines changed

src/sqlite-ai.c

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2149,20 +2149,44 @@ static void llm_chat_save (sqlite3_context *context, int argc, sqlite3_value **a
21492149
// start transaction
21502150
sqlite_db_write_simple(context, db, "BEGIN;");
21512151

2152-
// save chat
2153-
const char *sql = "INSERT INTO ai_chat_history (uuid, title, metadata) VALUES (?, ?, ?);";
2152+
// save chat, the ON CONFLICT allows saving multiple times
2153+
const char *sql = "INSERT INTO ai_chat_history (uuid, title, metadata) VALUES (?, ?, ?) "
2154+
"ON CONFLICT(uuid) DO UPDATE SET "
2155+
" title = excluded.title, "
2156+
" metadata = excluded.metadata, "
2157+
" created_at = CURRENT_TIMESTAMP;";
21542158
const char *values[] = {ai->chat.uuid, title, meta};
21552159
int types[] = {SQLITE_TEXT, SQLITE_TEXT, SQLITE_TEXT};
21562160
int lens[] = {-1, -1, -1};
21572161

21582162
int rc = sqlite_db_write(context, db, sql, values, types, lens, 3);
21592163
if (rc != SQLITE_OK) goto abort_save;
2160-
2161-
// loop to save messages (the context)
2164+
2165+
// get the rowid, cannot use sqlite3_last_insert_rowid for the CONFLICT case
21622166
char rowid_s[256];
2163-
sqlite3_int64 rowid = sqlite3_last_insert_rowid(db);
2167+
sqlite3_stmt *pstmt = NULL;
2168+
sql = "SELECT id FROM ai_chat_history WHERE uuid = ?;";
2169+
rc = sqlite3_prepare_v2(db, sql, -1, &pstmt, NULL);
2170+
if (rc != SQLITE_OK) goto abort_save;
2171+
rc = sqlite3_bind_text(pstmt, 1, ai->chat.uuid, -1, SQLITE_STATIC);
2172+
rc = sqlite3_step(pstmt);
2173+
if (rc != SQLITE_ROW) {
2174+
sqlite3_finalize(pstmt);
2175+
goto abort_save;
2176+
}
2177+
sqlite3_int64 rowid = sqlite3_column_int64(pstmt, 0);
2178+
sqlite3_finalize(pstmt);
21642179
snprintf(rowid_s, sizeof(rowid_s), "%lld", (long long)rowid);
2180+
2181+
// delete all messages for this chat id, if any
2182+
sql = "DELETE FROM ai_chat_messages WHERE chat_id = ?;";
2183+
const char *values3[] = {rowid_s};
2184+
int types3[] = {SQLITE_INTEGER};
2185+
int lens3[] = {-1};
2186+
rc = sqlite_db_write(context, db, sql, values3, types3, lens3, 1);
2187+
if (rc != SQLITE_OK) goto abort_save;
21652188

2189+
// loop to save messages (the context)
21662190
sql = "INSERT INTO ai_chat_messages (chat_id, role, content) VALUES (?, ?, ?);";
21672191
int types2[] = {SQLITE_INTEGER, SQLITE_TEXT, SQLITE_TEXT};
21682192

tests/c/unittest.c

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1812,6 +1812,120 @@ static int test_audio_model_load_free_cycle(const test_env *env) {
18121812
return 1;
18131813
}
18141814

1815+
static int test_llm_chat_double_save(const test_env *env) {
1816+
sqlite3 *db = NULL;
1817+
bool model_loaded = false;
1818+
bool context_created = false;
1819+
bool chat_created = false;
1820+
int status = 1;
1821+
1822+
if (open_db_and_load(env, &db) != SQLITE_OK) {
1823+
goto done;
1824+
}
1825+
1826+
const char *model = env->model_path ? env->model_path : DEFAULT_MODEL_PATH;
1827+
char sqlbuf[512];
1828+
snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_model_load('%s');", model);
1829+
if (exec_expect_ok(env, db, sqlbuf) != 0)
1830+
goto done;
1831+
model_loaded = true;
1832+
1833+
if (exec_expect_ok(env, db,
1834+
"SELECT llm_context_create('context_size=1000');") != 0)
1835+
goto done;
1836+
context_created = true;
1837+
1838+
if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0)
1839+
goto done;
1840+
chat_created = true;
1841+
1842+
// First prompt
1843+
const char *prompt1 = "First prompt";
1844+
if (exec_expect_ok(env, db, "SELECT llm_chat_respond('First prompt');") != 0)
1845+
goto done;
1846+
1847+
// First save
1848+
if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0)
1849+
goto done;
1850+
1851+
// Second prompt
1852+
const char *prompt2 = "Second prompt";
1853+
if (exec_expect_ok(env, db, "SELECT llm_chat_respond('Second prompt');") != 0)
1854+
goto done;
1855+
1856+
// Second save
1857+
if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0)
1858+
goto done;
1859+
1860+
ai_chat_message_row rows[8];
1861+
int count = 0;
1862+
// We expect 4 messages: User1, Assistant1, User2, Assistant2
1863+
if (fetch_ai_chat_messages(env, db, rows, 8, &count) != 0)
1864+
goto done;
1865+
1866+
if (count != 5) {
1867+
fprintf(stderr,
1868+
"[test_llm_chat_double_save] expected 4 message rows, got %d\n",
1869+
count);
1870+
goto done;
1871+
}
1872+
1873+
// Verify order and roles
1874+
if (strcmp(rows[0].role, "system") != 0 ||
1875+
strcmp(rows[0].content, "") != 0) {
1876+
fprintf(stderr,
1877+
"[test_llm_chat_double_save] row 0 mismatch (expected system/'%s', "
1878+
"got %s/'%s')\n",
1879+
"", rows[0].role, rows[0].content);
1880+
goto done;
1881+
}
1882+
if (strcmp(rows[1].role, "user") != 0 ||
1883+
strcmp(rows[1].content, prompt1) != 0) {
1884+
fprintf(stderr,
1885+
"[test_llm_chat_double_save] row 0 mismatch (expected user/'%s', "
1886+
"got %s/'%s')\n",
1887+
prompt1, rows[1].role, rows[1].content);
1888+
goto done;
1889+
}
1890+
if (strcmp(rows[2].role, "assistant") != 0) {
1891+
fprintf(stderr,
1892+
"[test_llm_chat_double_save] row 1 mismatch (expected assistant, "
1893+
"got %s)\n",
1894+
rows[2].role);
1895+
goto done;
1896+
}
1897+
if (strcmp(rows[3].role, "user") != 0 ||
1898+
strcmp(rows[3].content, prompt2) != 0) {
1899+
fprintf(stderr,
1900+
"[test_llm_chat_double_save] row 2 mismatch (expected user/'%s', "
1901+
"got %s/'%s')\n",
1902+
prompt2, rows[3].role, rows[3].content);
1903+
goto done;
1904+
}
1905+
if (strcmp(rows[4].role, "assistant") != 0) {
1906+
fprintf(stderr,
1907+
"[test_llm_chat_double_save] row 3 mismatch (expected assistant, "
1908+
"got %s)\n",
1909+
rows[4].role);
1910+
goto done;
1911+
}
1912+
1913+
status = 0;
1914+
1915+
done:
1916+
if (chat_created)
1917+
exec_expect_ok(env, db, "SELECT llm_chat_free();");
1918+
if (context_created)
1919+
exec_expect_ok(env, db, "SELECT llm_context_free();");
1920+
if (model_loaded)
1921+
exec_expect_ok(env, db, "SELECT llm_model_free();");
1922+
if (db)
1923+
sqlite3_close(db);
1924+
if (status == 0)
1925+
status = assert_sqlite_memory_clean("llm_chat_double_save", env);
1926+
return status;
1927+
}
1928+
18151929
static const test_case TESTS[] = {
18161930
{"issue15_llm_chat_without_context", test_issue15_chat_without_context},
18171931
{"llm_chat_respond_repeated", test_llm_chat_respond_repeated},
@@ -1844,6 +1958,7 @@ static const test_case TESTS[] = {
18441958
{"chat_respond_auto_init", test_chat_respond_auto_init},
18451959
{"chat_save_with_metadata", test_chat_save_with_metadata},
18461960
{"text_generate_default_limit", test_text_generate_default_limit},
1961+
{"llm_chat_double_save", test_llm_chat_double_save},
18471962
// Audio / Whisper tests
18481963
{"audio_transcribe_no_model", test_audio_transcribe_no_model},
18491964
{"audio_model_load_invalid_path", test_audio_model_load_invalid_path},

0 commit comments

Comments
 (0)