Skip to content

Commit 9becf93

Browse files
committed
Some bugs fixed
1 parent 0a18105 commit 9becf93

5 files changed

Lines changed: 130 additions & 9 deletions

File tree

src/dbmem-embed.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ void dbmem_local_engine_free (dbmem_local_engine_t *engine);
2929

3030
dbmem_remote_engine_t *dbmem_remote_engine_init (void *ctx, const char *provider, const char *model, char err_msg[DBMEM_ERRBUF_SIZE]);
3131
int dbmem_remote_compute_embedding (dbmem_remote_engine_t *engine, const char *text, int text_len, embedding_result_t *result);
32+
int dbmem_remote_engine_set_apikey (dbmem_remote_engine_t *engine, const char *api_key, char err_msg[DBMEM_ERRBUF_SIZE]);
3233
void dbmem_remote_engine_free (dbmem_remote_engine_t *engine);
3334

3435
// Custom provider (always available, defined in sqlite-memory.c)

src/dbmem-rembed.c

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ static size_t cacert_len = sizeof(cacert_pem) - 1;
2626

2727
#ifndef DBMEM_OMIT_CURL
2828
static size_t dbmem_remote_receive_data(void *contents, size_t size, size_t nmemb, void *xdata);
29+
static struct curl_slist *dbmem_remote_build_headers (const char *api_key);
2930
#endif
3031

3132
struct dbmem_remote_engine_t {
@@ -67,6 +68,27 @@ struct dbmem_remote_engine_t {
6768
#include <stdbool.h>
6869
#include <stddef.h>
6970

71+
#ifndef DBMEM_OMIT_CURL
72+
static struct curl_slist *dbmem_remote_build_headers (const char *api_key) {
73+
char auth_header[512];
74+
struct curl_slist *headers = NULL;
75+
struct curl_slist *next = NULL;
76+
77+
snprintf(auth_header, sizeof(auth_header), "Authorization: Bearer %s", api_key);
78+
headers = curl_slist_append(headers, auth_header);
79+
if (!headers) return NULL;
80+
81+
next = curl_slist_append(headers, "Content-Type: application/json");
82+
if (!next) {
83+
curl_slist_free_all(headers);
84+
return NULL;
85+
}
86+
headers = next;
87+
88+
return headers;
89+
}
90+
#endif
91+
7092
static bool text_needs_json_escape (const char *text, size_t *len) {
7193
size_t original_len = *len;
7294
size_t required_len = 0;
@@ -263,11 +285,7 @@ dbmem_remote_engine_t *dbmem_remote_engine_init (void *ctx, const char *provider
263285
#endif
264286

265287
// set up headers
266-
char auth_header[512];
267-
snprintf(auth_header, sizeof(auth_header), "Authorization: Bearer %s", api_key);
268-
struct curl_slist *headers = NULL;
269-
headers = curl_slist_append(headers, auth_header);
270-
if (headers) headers = curl_slist_append(headers, "Content-Type: application/json");
288+
struct curl_slist *headers = dbmem_remote_build_headers(api_key);
271289
if (!headers) {
272290
snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to allocate HTTP headers");
273291
curl_easy_cleanup(curl);
@@ -522,6 +540,36 @@ int dbmem_remote_compute_embedding (dbmem_remote_engine_t *engine, const char *t
522540
return 0;
523541
}
524542

543+
int dbmem_remote_engine_set_apikey (dbmem_remote_engine_t *engine, const char *api_key, char err_msg[DBMEM_ERRBUF_SIZE]) {
544+
if (!engine || !api_key) {
545+
if (err_msg) snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Invalid remote engine or API key");
546+
return SQLITE_MISUSE;
547+
}
548+
549+
#ifndef DBMEM_OMIT_CURL
550+
struct curl_slist *headers = dbmem_remote_build_headers(api_key);
551+
if (!headers) {
552+
if (err_msg) snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to allocate HTTP headers");
553+
return SQLITE_NOMEM;
554+
}
555+
556+
curl_easy_setopt(engine->curl, CURLOPT_HTTPHEADER, headers);
557+
if (engine->headers) curl_slist_free_all(engine->headers);
558+
engine->headers = headers;
559+
#else
560+
char *copy = dbmem_strdup(api_key);
561+
if (!copy) {
562+
if (err_msg) snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Unable to duplicate API key (insufficient memory)");
563+
return SQLITE_NOMEM;
564+
}
565+
566+
if (engine->api_key) dbmemory_free(engine->api_key);
567+
engine->api_key = copy;
568+
#endif
569+
570+
return SQLITE_OK;
571+
}
572+
525573
void dbmem_remote_engine_free (dbmem_remote_engine_t *engine) {
526574
if (!engine) return;
527575

src/dbmem-search.c

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <stdio.h>
2020
#include <float.h>
2121
#include <limits.h>
22+
#include <stddef.h>
2223
#include <time.h>
2324

2425
#ifndef SQLITE_CORE
@@ -92,9 +93,15 @@ static bool dbmem_search_column_hash (sqlite3_stmt *vm, int column, uint64_t *ha
9293

9394
// MARK: - UTILS -
9495

96+
static void vMemorySearchCursorReset (vMemorySearchCursor *c) {
97+
if (c->buffer) dbmemory_free(c->buffer);
98+
memset((char *)c + offsetof(vMemorySearchCursor, max_results), 0,
99+
sizeof(*c) - offsetof(vMemorySearchCursor, max_results));
100+
}
101+
95102
int vMemorySearchCursorAllocate (vMemorySearchCursor *c, int entries, bool perform_fts) {
96103
if (entries <= 0) {
97-
memset(c, 0, sizeof(*c));
104+
vMemorySearchCursorReset(c);
98105
c->max_results = entries;
99106
c->perform_fts = perform_fts;
100107
return SQLITE_OK;
@@ -527,7 +534,7 @@ static int vMemorySearchCursorOpen (sqlite3_vtab *pVtab, sqlite3_vtab_cursor **p
527534

528535
static int vMemorySearchCursorClose (sqlite3_vtab_cursor *cur){
529536
vMemorySearchCursor *c = (vMemorySearchCursor *)cur;
530-
if (c->buffer) dbmemory_free(c->buffer);
537+
vMemorySearchCursorReset(c);
531538
dbmemory_free(c);
532539
return SQLITE_OK;
533540
}
@@ -638,6 +645,7 @@ static int vMemorySearchCursorFilter (sqlite3_vtab_cursor *cur, int idxNum, cons
638645
fetch_count = max_results;
639646
}
640647

648+
vMemorySearchCursorReset(c);
641649
if (fetch_count <= 0) {
642650
c->count = 0;
643651
return SQLITE_OK;

src/sqlite-memory.c

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,16 @@ static void dbmem_set_apikey (sqlite3_context *context, int argc, sqlite3_value
10701070

10711071
// retrieve context
10721072
dbmem_context *ctx = (dbmem_context *)sqlite3_user_data(context);
1073-
1073+
1074+
if (ctx->r_engine && !ctx->is_local && !ctx->is_custom) {
1075+
int rc = dbmem_remote_engine_set_apikey(ctx->r_engine, apikey, ctx->error_msg);
1076+
if (rc != SQLITE_OK) {
1077+
dbmemory_free(apikey);
1078+
sqlite3_result_error(context, ctx->error_msg[0] ? ctx->error_msg : "Unable to update remote API key", -1);
1079+
return;
1080+
}
1081+
}
1082+
10741083
if (ctx->api_key) dbmemory_free(ctx->api_key);
10751084
ctx->api_key = apikey;
10761085

@@ -1649,7 +1658,15 @@ static void dbmem_sql_reindex (sqlite3_context *context, int argc, sqlite3_value
16491658
if (rc != SQLITE_OK) break;
16501659

16511660
int step = sqlite3_step(vm);
1652-
if (step != SQLITE_ROW) { sqlite3_finalize(vm); break; }
1661+
if (step == SQLITE_DONE) {
1662+
sqlite3_finalize(vm);
1663+
break;
1664+
}
1665+
if (step != SQLITE_ROW) {
1666+
sqlite3_finalize(vm);
1667+
rc = step;
1668+
break;
1669+
}
16531670

16541671
// Copy row data before finalizing so we can write in the next step
16551672
const char *path_raw = (const char *)sqlite3_column_text(vm, 0);

test/e2e.c

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,52 @@ TEST(memory_search_ranking) {
378378
ASSERT_SQL_OK(db, "SELECT memory_set_option('min_score', 0.7);");
379379
}
380380

381+
TEST(memory_search_statement_reuse) {
382+
sqlite3_stmt *stmt = NULL;
383+
int rc = sqlite3_prepare_v2(db,
384+
"SELECT hash, snippet FROM memory_search(?1, ?2);",
385+
-1, &stmt, NULL);
386+
ASSERT(rc == SQLITE_OK);
387+
388+
rc = sqlite3_bind_text(stmt, 1, "fox", -1, SQLITE_STATIC);
389+
ASSERT(rc == SQLITE_OK);
390+
rc = sqlite3_bind_int(stmt, 2, 5);
391+
ASSERT(rc == SQLITE_OK);
392+
393+
int first_count = 0;
394+
while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
395+
const char *hash = (const char *)sqlite3_column_text(stmt, 0);
396+
const char *snippet = (const char *)sqlite3_column_text(stmt, 1);
397+
ASSERT(hash != NULL && strlen(hash) == DBMEM_HASH_HEX_LEN);
398+
ASSERT(snippet != NULL && strlen(snippet) > 0);
399+
first_count++;
400+
}
401+
ASSERT(rc == SQLITE_DONE);
402+
ASSERT(first_count > 0);
403+
404+
rc = sqlite3_reset(stmt);
405+
ASSERT(rc == SQLITE_OK);
406+
sqlite3_clear_bindings(stmt);
407+
408+
rc = sqlite3_bind_text(stmt, 1, "SQL database engine", -1, SQLITE_STATIC);
409+
ASSERT(rc == SQLITE_OK);
410+
rc = sqlite3_bind_int(stmt, 2, 10);
411+
ASSERT(rc == SQLITE_OK);
412+
413+
int second_count = 0;
414+
while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
415+
const char *hash = (const char *)sqlite3_column_text(stmt, 0);
416+
const char *snippet = (const char *)sqlite3_column_text(stmt, 1);
417+
ASSERT(hash != NULL && strlen(hash) == DBMEM_HASH_HEX_LEN);
418+
ASSERT(snippet != NULL && strlen(snippet) > 0);
419+
second_count++;
420+
}
421+
ASSERT(rc == SQLITE_DONE);
422+
ASSERT(second_count > 0);
423+
424+
sqlite3_finalize(stmt);
425+
}
426+
381427
// ============================================================================
382428
// Phase 5: Deletion
383429
// ============================================================================
@@ -495,6 +541,7 @@ int main(void) {
495541
// Phase 4: Search (network calls)
496542
RUN_TEST(memory_search);
497543
RUN_TEST(memory_search_ranking);
544+
RUN_TEST(memory_search_statement_reuse);
498545

499546
// Phase 5: Deletion
500547
RUN_TEST(memory_delete);

0 commit comments

Comments
 (0)